MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Matchers.h"
24#include "mlir/IR/SymbolTable.h"
26
27#include "llvm/ADT/ArrayRef.h"
28#include "llvm/ADT/PostOrderIterator.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/STLForwardCompat.h"
31#include "llvm/ADT/SmallString.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/StringRef.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/ADT/bit.h"
36#include "llvm/Support/InterleavedRange.h"
37#include <cstddef>
38#include <iterator>
39#include <optional>
40#include <variant>
41
42#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
46
47using namespace mlir;
48using namespace mlir::omp;
49
52 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
53}
54
57 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
58}
59
62 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
63}
64
65namespace {
66struct MemRefPointerLikeModel
67 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
68 MemRefType> {
69 Type getElementType(Type pointer) const {
70 return llvm::cast<MemRefType>(pointer).getElementType();
71 }
72};
73
74struct LLVMPointerPointerLikeModel
75 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
77 Type getElementType(Type pointer) const { return Type(); }
78};
79} // namespace
80
81/// Generate a name of a canonical loop nest of the format
82/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
83/// argument index of an operation that has multiple regions, if the operation
84/// has multiple regions.
85/// `_s<idx>` identifies the position of an operation within a region, where
86/// only operations that may potentially contain loops ("container operations"
87/// i.e. have region arguments) are counted. Again, it is omitted if there is
88/// only one such operation in a region. If there are canonical loops nested
89/// inside each other, also may also use the format `_d<num>` where <num> is the
90/// nesting depth of the loop.
91///
92/// The generated name is a best-effort to make canonical loop unique within an
93/// SSA namespace. This also means that regions with IsolatedFromAbove property
94/// do not consider any parents or siblings.
95static std::string generateLoopNestingName(StringRef prefix,
96 CanonicalLoopOp op) {
97 struct Component {
98 /// If true, this component describes a region operand of an operation (the
99 /// operand's owner) If false, this component describes an operation located
100 /// in a parent region
101 bool isRegionArgOfOp;
102 bool skip = false;
103 bool isUnique = false;
104
105 size_t idx;
106 Operation *op;
107 Region *parentRegion;
108 size_t loopDepth;
109
110 Operation *&getOwnerOp() {
111 assert(isRegionArgOfOp && "Must describe a region operand");
112 return op;
113 }
114 size_t &getArgIdx() {
115 assert(isRegionArgOfOp && "Must describe a region operand");
116 return idx;
117 }
118
119 Operation *&getContainerOp() {
120 assert(!isRegionArgOfOp && "Must describe a operation of a region");
121 return op;
122 }
123 size_t &getOpPos() {
124 assert(!isRegionArgOfOp && "Must describe a operation of a region");
125 return idx;
126 }
127 bool isLoopOp() const {
128 assert(!isRegionArgOfOp && "Must describe a operation of a region");
129 return isa<CanonicalLoopOp>(op);
130 }
131 Region *&getParentRegion() {
132 assert(!isRegionArgOfOp && "Must describe a operation of a region");
133 return parentRegion;
134 }
135 size_t &getLoopDepth() {
136 assert(!isRegionArgOfOp && "Must describe a operation of a region");
137 return loopDepth;
138 }
139
140 void skipIf(bool v = true) { skip = skip || v; }
141 };
142
143 // List of ancestors, from inner to outer.
144 // Alternates between
145 // * region argument of an operation
146 // * operation within a region
147 SmallVector<Component> components;
148
149 // Gather a list of parent regions and operations, and the position within
150 // their parent
151 Operation *o = op.getOperation();
152 while (o) {
153 // Operation within a region
154 Region *r = o->getParentRegion();
155 if (!r)
156 break;
157
158 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
159 size_t idx = 0;
160 bool found = false;
161 size_t sequentialIdx = -1;
162 bool isOnlyContainerOp = true;
163 for (Block *b : traversal) {
164 for (Operation &op : *b) {
165 if (&op == o && !found) {
166 sequentialIdx = idx;
167 found = true;
168 }
169 if (op.getNumRegions()) {
170 idx += 1;
171 if (idx > 1)
172 isOnlyContainerOp = false;
173 }
174 if (found && !isOnlyContainerOp)
175 break;
176 }
177 }
178
179 Component &containerOpInRegion = components.emplace_back();
180 containerOpInRegion.isRegionArgOfOp = false;
181 containerOpInRegion.isUnique = isOnlyContainerOp;
182 containerOpInRegion.getContainerOp() = o;
183 containerOpInRegion.getOpPos() = sequentialIdx;
184 containerOpInRegion.getParentRegion() = r;
185
186 Operation *parent = r->getParentOp();
187
188 // Region argument of an operation
189 Component &regionArgOfOperation = components.emplace_back();
190 regionArgOfOperation.isRegionArgOfOp = true;
191 regionArgOfOperation.isUnique = true;
192 regionArgOfOperation.getArgIdx() = 0;
193 regionArgOfOperation.getOwnerOp() = parent;
194
195 // The IsolatedFromAbove trait of the parent operation implies that each
196 // individual region argument has its own separate namespace, so no
197 // ambiguity.
198 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
199 break;
200
201 // Component only needed if operation has multiple region operands. Region
202 // arguments may be optional, but we currently do not consider this.
203 if (parent->getRegions().size() > 1) {
204 auto getRegionIndex = [](Operation *o, Region *r) {
205 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
206 if (&region == r)
207 return idx;
208 }
209 llvm_unreachable("Region not child of its parent operation");
210 };
211 regionArgOfOperation.isUnique = false;
212 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
213 }
214
215 // next parent
216 o = parent;
217 }
218
219 // Determine whether a region-argument component is not needed
220 for (Component &c : components)
221 c.skipIf(c.isRegionArgOfOp && c.isUnique);
222
223 // Find runs of nested loops and determine each loop's depth in the loop nest
224 size_t numSurroundingLoops = 0;
225 for (Component &c : llvm::reverse(components)) {
226 if (c.skip)
227 continue;
228
229 // non-skipped multi-argument operands interrupt the loop nest
230 if (c.isRegionArgOfOp) {
231 numSurroundingLoops = 0;
232 continue;
233 }
234
235 // Multiple loops in a region means each of them is the outermost loop of a
236 // new loop nest
237 if (!c.isUnique)
238 numSurroundingLoops = 0;
239
240 c.getLoopDepth() = numSurroundingLoops;
241
242 // Next loop is surrounded by one more loop
243 if (isa<CanonicalLoopOp>(c.getContainerOp()))
244 numSurroundingLoops += 1;
245 }
246
247 // In loop nests, skip all but the innermost loop that contains the depth
248 // number
249 bool isLoopNest = false;
250 for (Component &c : components) {
251 if (c.skip || c.isRegionArgOfOp)
252 continue;
253
254 if (!isLoopNest && c.getLoopDepth() >= 1) {
255 // Innermost loop of a loop nest of at least two loops
256 isLoopNest = true;
257 } else if (isLoopNest) {
258 // Non-innermost loop of a loop nest
259 c.skipIf(c.isUnique);
260
261 // If there is no surrounding loop left, this must have been the outermost
262 // loop; leave loop-nest mode for the next iteration
263 if (c.getLoopDepth() == 0)
264 isLoopNest = false;
265 }
266 }
267
268 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
269 // we mark them as skipped only after computing loop nests)
270 for (Component &c : components)
271 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
272 !isa<CanonicalLoopOp>(c.getContainerOp()));
273
274 // Components can be skipped if they are already disambiguated by their parent
275 // (or does not have a parent)
276 bool newRegion = true;
277 for (Component &c : llvm::reverse(components)) {
278 c.skipIf(newRegion && c.isUnique);
279
280 // non-skipped components disambiguate unique children
281 if (!c.skip)
282 newRegion = true;
283
284 // ...except canonical loops that need a suffix for each nest
285 if (!c.isRegionArgOfOp && c.getContainerOp())
286 newRegion = false;
287 }
288
289 // Compile the nesting name string
290 SmallString<64> Name{prefix};
291 llvm::raw_svector_ostream NameOS(Name);
292 for (auto &c : llvm::reverse(components)) {
293 if (c.skip)
294 continue;
295
296 if (c.isRegionArgOfOp)
297 NameOS << "_r" << c.getArgIdx();
298 else if (c.getLoopDepth() >= 1)
299 NameOS << "_d" << c.getLoopDepth();
300 else
301 NameOS << "_s" << c.getOpPos();
302 }
303
304 return NameOS.str().str();
305}
306
307void OpenMPDialect::initialize() {
308 addOperations<
309#define GET_OP_LIST
310#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
311 >();
312 addAttributes<
313#define GET_ATTRDEF_LIST
314#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
315 >();
316 addTypes<
317#define GET_TYPEDEF_LIST
318#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
319 >();
320
321 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
322
323 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
324 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
325 *getContext());
326
327 // Attach default offload module interface to module op to access
328 // offload functionality through
329 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
330 *getContext());
331
332 // Attach default declare target interfaces to operations which can be marked
333 // as declare target (Global Operations and Functions/Subroutines in dialects
334 // that Fortran (or other languages that lower to MLIR) translates too
335 mlir::LLVM::GlobalOp::attachInterface<
337 *getContext());
338 mlir::LLVM::LLVMFuncOp::attachInterface<
340 *getContext());
341 mlir::func::FuncOp::attachInterface<
343}
344
345//===----------------------------------------------------------------------===//
346// Parser and printer for Allocate Clause
347//===----------------------------------------------------------------------===//
348
349/// Parse an allocate clause with allocators and a list of operands with types.
350///
351/// allocate-operand-list :: = allocate-operand |
352/// allocator-operand `,` allocate-operand-list
353/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
354/// ssa-id-and-type ::= ssa-id `:` type
355static ParseResult parseAllocateAndAllocator(
356 OpAsmParser &parser,
358 SmallVectorImpl<Type> &allocateTypes,
360 SmallVectorImpl<Type> &allocatorTypes) {
361
362 return parser.parseCommaSeparatedList([&]() {
364 Type type;
365 if (parser.parseOperand(operand) || parser.parseColonType(type))
366 return failure();
367 allocatorVars.push_back(operand);
368 allocatorTypes.push_back(type);
369 if (parser.parseArrow())
370 return failure();
371 if (parser.parseOperand(operand) || parser.parseColonType(type))
372 return failure();
373
374 allocateVars.push_back(operand);
375 allocateTypes.push_back(type);
376 return success();
377 });
378}
379
380/// Print allocate clause
382 OperandRange allocateVars,
383 TypeRange allocateTypes,
384 OperandRange allocatorVars,
385 TypeRange allocatorTypes) {
386 for (unsigned i = 0; i < allocateVars.size(); ++i) {
387 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
388 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
389 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
390 }
391}
392
393//===----------------------------------------------------------------------===//
394// Parser and printer for a clause attribute (StringEnumAttr)
395//===----------------------------------------------------------------------===//
396
397template <typename ClauseAttr>
398static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
399 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
400 StringRef enumStr;
401 SMLoc loc = parser.getCurrentLocation();
402 if (parser.parseKeyword(&enumStr))
403 return failure();
404 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
405 attr = ClauseAttr::get(parser.getContext(), *enumValue);
406 return success();
407 }
408 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
409}
410
411template <typename ClauseAttr>
412static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
413 p << stringifyEnum(attr.getValue());
414}
415
416//===----------------------------------------------------------------------===//
417// Parser and printer for Linear Clause
418//===----------------------------------------------------------------------===//
419
420/// linear ::= `linear` `(` linear-list `)`
421/// linear-list := linear-val | linear-val linear-list
422/// linear-val := ssa-id-and-type `=` ssa-id-and-type
423/// | `val` `(` ssa-id-and-type `=` ssa-id-and-type `)`
424/// | `ref` `(` ssa-id-and-type `=` ssa-id-and-type `)`
425/// | `uval` `(` ssa-id-and-type `=` ssa-id-and-type `)`
426static ParseResult parseLinearClause(
427 OpAsmParser &parser,
429 SmallVectorImpl<Type> &linearTypes,
431 SmallVectorImpl<Type> &linearStepTypes, ArrayAttr &linearModifiers) {
432 SmallVector<Attribute> modifiers;
433 auto result = parser.parseCommaSeparatedList([&]() {
435 Type type, stepType;
437
438 std::optional<omp::LinearModifier> linearModifier;
439 if (succeeded(parser.parseOptionalKeyword("val"))) {
440 linearModifier = omp::LinearModifier::val;
441 } else if (succeeded(parser.parseOptionalKeyword("ref"))) {
442 linearModifier = omp::LinearModifier::ref;
443 } else if (succeeded(parser.parseOptionalKeyword("uval"))) {
444 linearModifier = omp::LinearModifier::uval;
445 }
446
447 bool hasLinearModifierParens = linearModifier.has_value();
448 if (hasLinearModifierParens && parser.parseLParen())
449 return failure();
450
451 if (parser.parseOperand(var) || parser.parseColonType(type) ||
452 parser.parseEqual() || parser.parseOperand(stepVar) ||
453 parser.parseColonType(stepType))
454 return failure();
455
456 if (hasLinearModifierParens && parser.parseRParen())
457 return failure();
458
459 linearVars.push_back(var);
460 linearTypes.push_back(type);
461 linearStepVars.push_back(stepVar);
462 linearStepTypes.push_back(stepType);
463 if (linearModifier) {
464 modifiers.push_back(
465 omp::LinearModifierAttr::get(parser.getContext(), *linearModifier));
466 } else {
467 modifiers.push_back(UnitAttr::get(parser.getContext()));
468 }
469 return success();
470 });
471 if (failed(result))
472 return failure();
473 linearModifiers = ArrayAttr::get(parser.getContext(), modifiers);
474 return success();
475}
476
477/// Print Linear Clause
479 ValueRange linearVars, TypeRange linearTypes,
480 ValueRange linearStepVars, TypeRange stepVarTypes,
481 ArrayAttr linearModifiers) {
482 size_t linearVarsSize = linearVars.size();
483 for (unsigned i = 0; i < linearVarsSize; ++i) {
484 if (i != 0)
485 p << ", ";
486 // Print modifier keyword wrapper if present.
487 Attribute modAttr = linearModifiers ? linearModifiers[i] : nullptr;
488 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) : nullptr;
489 if (mod) {
490 p << omp::stringifyLinearModifier(mod.getValue()) << "(";
491 }
492 p << linearVars[i] << " : " << linearTypes[i];
493 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
494 if (mod)
495 p << ")";
496 }
497}
498
499//===----------------------------------------------------------------------===//
500// Verifier for Linear modifier
501//===----------------------------------------------------------------------===//
502
503/// OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or
504/// uval only on a declare simd directive."
505/// Also verifies that modifier count matches variable count.
506static LogicalResult
507verifyLinearModifiers(Operation *op, std::optional<ArrayAttr> linearModifiers,
508 OperandRange linearVars, bool isDeclareSimd = false) {
509 if (!linearModifiers)
510 return success();
511 if (linearModifiers->size() != linearVars.size())
512 return op->emitOpError()
513 << "expected as many linear modifiers as linear variables";
514 if (!isDeclareSimd) {
515 for (Attribute attr : *linearModifiers) {
516 if (!attr)
517 continue;
518 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
519 if (!modAttr)
520 continue;
521 omp::LinearModifier mod = modAttr.getValue();
522 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
523 return op->emitOpError()
524 << "linear modifier '" << omp::stringifyLinearModifier(mod)
525 << "' may only be specified on a declare simd directive";
526 }
527 }
528 return success();
529}
530
531//===----------------------------------------------------------------------===//
532// Verifier for Nontemporal Clause
533//===----------------------------------------------------------------------===//
534
535static LogicalResult verifyNontemporalClause(Operation *op,
536 OperandRange nontemporalVars) {
537
538 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
539 DenseSet<Value> nontemporalItems;
540 for (const auto &it : nontemporalVars)
541 if (!nontemporalItems.insert(it).second)
542 return op->emitOpError() << "nontemporal variable used more than once";
543
544 return success();
545}
546
547//===----------------------------------------------------------------------===//
548// Parser, verifier and printer for Aligned Clause
549//===----------------------------------------------------------------------===//
550static LogicalResult verifyAlignedClause(Operation *op,
551 std::optional<ArrayAttr> alignments,
552 OperandRange alignedVars) {
553 // Check if number of alignment values equals to number of aligned variables
554 if (!alignedVars.empty()) {
555 if (!alignments || alignments->size() != alignedVars.size())
556 return op->emitOpError()
557 << "expected as many alignment values as aligned variables";
558 } else {
559 if (alignments)
560 return op->emitOpError() << "unexpected alignment values attribute";
561 return success();
562 }
563
564 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
565 DenseSet<Value> alignedItems;
566 for (auto it : alignedVars)
567 if (!alignedItems.insert(it).second)
568 return op->emitOpError() << "aligned variable used more than once";
569
570 if (!alignments)
571 return success();
572
573 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
574 for (unsigned i = 0; i < (*alignments).size(); ++i) {
575 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
576 if (intAttr.getValue().sle(0))
577 return op->emitOpError() << "alignment should be greater than 0";
578 } else {
579 return op->emitOpError() << "expected integer alignment";
580 }
581 }
582
583 return success();
584}
585
586/// aligned ::= `aligned` `(` aligned-list `)`
587/// aligned-list := aligned-val | aligned-val aligned-list
588/// aligned-val := ssa-id-and-type `->` alignment
589static ParseResult
592 SmallVectorImpl<Type> &alignedTypes,
593 ArrayAttr &alignmentsAttr) {
594 SmallVector<Attribute> alignmentVec;
595 if (failed(parser.parseCommaSeparatedList([&]() {
596 if (parser.parseOperand(alignedVars.emplace_back()) ||
597 parser.parseColonType(alignedTypes.emplace_back()) ||
598 parser.parseArrow() ||
599 parser.parseAttribute(alignmentVec.emplace_back())) {
600 return failure();
601 }
602 return success();
603 })))
604 return failure();
605 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
606 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
607 return success();
608}
609
610/// Print Aligned Clause
612 ValueRange alignedVars, TypeRange alignedTypes,
613 std::optional<ArrayAttr> alignments) {
614 for (unsigned i = 0; i < alignedVars.size(); ++i) {
615 if (i != 0)
616 p << ", ";
617 p << alignedVars[i] << " : " << alignedVars[i].getType();
618 p << " -> " << (*alignments)[i];
619 }
620}
621
622//===----------------------------------------------------------------------===//
623// Parser, printer and verifier for Schedule Clause
624//===----------------------------------------------------------------------===//
625
626static ParseResult
628 SmallVectorImpl<SmallString<12>> &modifiers) {
629 if (modifiers.size() > 2)
630 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
631 for (const auto &mod : modifiers) {
632 // Translate the string. If it has no value, then it was not a valid
633 // modifier!
634 auto symbol = symbolizeScheduleModifier(mod);
635 if (!symbol)
636 return parser.emitError(parser.getNameLoc())
637 << " unknown modifier type: " << mod;
638 }
639
640 // If we have one modifier that is "simd", then stick a "none" modiifer in
641 // index 0.
642 if (modifiers.size() == 1) {
643 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
644 modifiers.push_back(modifiers[0]);
645 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
646 }
647 } else if (modifiers.size() == 2) {
648 // If there are two modifier:
649 // First modifier should not be simd, second one should be simd
650 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
651 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
652 return parser.emitError(parser.getNameLoc())
653 << " incorrect modifier order";
654 }
655 return success();
656}
657
658/// schedule ::= `schedule` `(` sched-list `)`
659/// sched-list ::= sched-val | sched-val sched-list |
660/// sched-val `,` sched-modifier
661/// sched-val ::= sched-with-chunk | sched-wo-chunk
662/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
663/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
664/// sched-wo-chunk ::= `auto` | `runtime`
665/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
666/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
667static ParseResult
668parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
669 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
670 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
671 Type &chunkType) {
672 StringRef keyword;
673 if (parser.parseKeyword(&keyword))
674 return failure();
675 std::optional<mlir::omp::ClauseScheduleKind> schedule =
676 symbolizeClauseScheduleKind(keyword);
677 if (!schedule)
678 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
679
680 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
681 switch (*schedule) {
682 case ClauseScheduleKind::Static:
683 case ClauseScheduleKind::Dynamic:
684 case ClauseScheduleKind::Guided:
685 if (succeeded(parser.parseOptionalEqual())) {
686 chunkSize = OpAsmParser::UnresolvedOperand{};
687 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
688 return failure();
689 } else {
690 chunkSize = std::nullopt;
691 }
692 break;
693 case ClauseScheduleKind::Auto:
694 case ClauseScheduleKind::Runtime:
695 case ClauseScheduleKind::Distribute:
696 chunkSize = std::nullopt;
697 }
698
699 // If there is a comma, we have one or more modifiers..
701 while (succeeded(parser.parseOptionalComma())) {
702 StringRef mod;
703 if (parser.parseKeyword(&mod))
704 return failure();
705 modifiers.push_back(mod);
706 }
707
708 if (verifyScheduleModifiers(parser, modifiers))
709 return failure();
710
711 if (!modifiers.empty()) {
712 SMLoc loc = parser.getCurrentLocation();
713 if (std::optional<ScheduleModifier> mod =
714 symbolizeScheduleModifier(modifiers[0])) {
715 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
716 } else {
717 return parser.emitError(loc, "invalid schedule modifier");
718 }
719 // Only SIMD attribute is allowed here!
720 if (modifiers.size() > 1) {
721 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
722 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
723 }
724 }
725
726 return success();
727}
728
729/// Print schedule clause
731 ClauseScheduleKindAttr scheduleKind,
732 ScheduleModifierAttr scheduleMod,
733 UnitAttr scheduleSimd, Value scheduleChunk,
734 Type scheduleChunkType) {
735 p << stringifyClauseScheduleKind(scheduleKind.getValue());
736 if (scheduleChunk)
737 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
738 if (scheduleMod)
739 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
740 if (scheduleSimd)
741 p << ", simd";
742}
743
744//===----------------------------------------------------------------------===//
745// Parser and printer for Order Clause
746//===----------------------------------------------------------------------===//
747
748// order ::= `order` `(` [order-modifier ':'] concurrent `)`
749// order-modifier ::= reproducible | unconstrained
750static ParseResult parseOrderClause(OpAsmParser &parser,
751 ClauseOrderKindAttr &order,
752 OrderModifierAttr &orderMod) {
753 StringRef enumStr;
754 SMLoc loc = parser.getCurrentLocation();
755 if (parser.parseKeyword(&enumStr))
756 return failure();
757 if (std::optional<OrderModifier> enumValue =
758 symbolizeOrderModifier(enumStr)) {
759 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
760 if (parser.parseOptionalColon())
761 return failure();
762 loc = parser.getCurrentLocation();
763 if (parser.parseKeyword(&enumStr))
764 return failure();
765 }
766 if (std::optional<ClauseOrderKind> enumValue =
767 symbolizeClauseOrderKind(enumStr)) {
768 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
769 return success();
770 }
771 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
772}
773
775 ClauseOrderKindAttr order,
776 OrderModifierAttr orderMod) {
777 if (orderMod)
778 p << stringifyOrderModifier(orderMod.getValue()) << ":";
779 if (order)
780 p << stringifyClauseOrderKind(order.getValue());
781}
782
783template <typename ClauseTypeAttr, typename ClauseType>
784static ParseResult
785parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
786 std::optional<OpAsmParser::UnresolvedOperand> &operand,
787 Type &operandType,
788 std::optional<ClauseType> (*symbolizeClause)(StringRef),
789 StringRef clauseName) {
790 StringRef enumStr;
791 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
792 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
793 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
794 if (parser.parseComma())
795 return failure();
796 } else {
797 return parser.emitError(parser.getCurrentLocation())
798 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
799 ;
800 }
801 }
802
804 if (succeeded(parser.parseOperand(var))) {
805 operand = var;
806 } else {
807 return parser.emitError(parser.getCurrentLocation())
808 << "expected " << clauseName << " operand";
809 }
810
811 if (operand.has_value()) {
812 if (parser.parseColonType(operandType))
813 return failure();
814 }
815
816 return success();
817}
818
819template <typename ClauseTypeAttr, typename ClauseType>
820static void
822 ClauseTypeAttr prescriptiveness, Value operand,
823 mlir::Type operandType,
824 StringRef (*stringifyClauseType)(ClauseType)) {
825
826 if (prescriptiveness)
827 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
828
829 if (operand)
830 p << operand << ": " << operandType;
831}
832
833//===----------------------------------------------------------------------===//
834// Parser and printer for grainsize Clause
835//===----------------------------------------------------------------------===//
836
837// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
838static ParseResult
839parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
840 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
841 Type &grainsizeType) {
843 parser, grainsizeMod, grainsize, grainsizeType,
844 &symbolizeClauseGrainsizeType, "grainsize");
845}
846
848 ClauseGrainsizeTypeAttr grainsizeMod,
849 Value grainsize, mlir::Type grainsizeType) {
851 p, op, grainsizeMod, grainsize, grainsizeType,
852 &stringifyClauseGrainsizeType);
853}
854
855//===----------------------------------------------------------------------===//
856// Parser and printer for num_tasks Clause
857//===----------------------------------------------------------------------===//
858
859// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
860static ParseResult
861parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
862 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
863 Type &numTasksType) {
865 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
866 "num_tasks");
867}
868
870 ClauseNumTasksTypeAttr numTasksMod,
871 Value numTasks, mlir::Type numTasksType) {
873 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
874}
875
876//===----------------------------------------------------------------------===//
877// Parsers for operations including clauses that define entry block arguments.
878//===----------------------------------------------------------------------===//
879
880namespace {
881struct MapParseArgs {
882 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
883 SmallVectorImpl<Type> &types;
884 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
885 SmallVectorImpl<Type> &types)
886 : vars(vars), types(types) {}
887};
888struct PrivateParseArgs {
889 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
890 llvm::SmallVectorImpl<Type> &types;
891 ArrayAttr &syms;
892 UnitAttr &needsBarrier;
893 DenseI64ArrayAttr *mapIndices;
894 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
895 SmallVectorImpl<Type> &types, ArrayAttr &syms,
896 UnitAttr &needsBarrier,
897 DenseI64ArrayAttr *mapIndices = nullptr)
898 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
899 mapIndices(mapIndices) {}
900};
901
902struct ReductionParseArgs {
903 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
904 SmallVectorImpl<Type> &types;
905 DenseBoolArrayAttr &byref;
906 ArrayAttr &syms;
907 ReductionModifierAttr *modifier;
908 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
909 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
910 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
911 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
912};
913
914struct AllRegionParseArgs {
915 std::optional<MapParseArgs> hasDeviceAddrArgs;
916 std::optional<MapParseArgs> hostEvalArgs;
917 std::optional<ReductionParseArgs> inReductionArgs;
918 std::optional<MapParseArgs> mapArgs;
919 std::optional<PrivateParseArgs> privateArgs;
920 std::optional<ReductionParseArgs> reductionArgs;
921 std::optional<ReductionParseArgs> taskReductionArgs;
922 std::optional<MapParseArgs> useDeviceAddrArgs;
923 std::optional<MapParseArgs> useDevicePtrArgs;
924};
925} // namespace
926
927static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
928 return "private_barrier";
929}
930
/// Parses the parenthesized body of a clause that defines entry block
/// arguments. Parsed operands are appended to `operands`, their types to
/// `types`, and the bound region arguments to `regionPrivateArgs`.
///
/// Accepted form (bracketed parts are only parsed when the corresponding
/// out-parameter is non-null):
///   `(` [`mod` `:` modifier-keyword `,`]
///       entry (`,` entry)* `:` type (`,` type)* `)` [barrier-marker]
///   entry ::= [`byref`] [symbol-attr] operand `->` argument
///             [`[` `map_idx` `=` integer `]`]
///
/// On success, the optional outputs (`symbols`, `mapIndices`, `byref`,
/// `modifier`, `needsBarrier`) are filled from what was parsed, and every
/// region argument parsed by this call gets the type at the same position in
/// the type list.
static ParseResult parseClauseWithRegionArgs(
    OpAsmParser &parser,
    ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
    DenseBoolArrayAttr *byref = nullptr,
    ReductionModifierAttr *modifier = nullptr,
    UnitAttr *needsBarrier = nullptr) {
  SmallVector<int64_t> mapIndicesVec;
  SmallVector<bool> isByRefVec;
  // Region arguments from earlier clauses already live in regionPrivateArgs;
  // only the ones appended from this offset onward are typed below.
  unsigned regionArgOffset = regionPrivateArgs.size();

  if (parser.parseLParen())
    return failure();

  // Optional leading reduction modifier: `mod: <keyword>,`.
  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
    StringRef enumStr;
    if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
        parser.parseComma())
      return failure();
    std::optional<ReductionModifier> enumValue =
        symbolizeReductionModifier(enumStr);
    // NOTE(review): an unknown modifier fails without its own diagnostic; the
    // caller reports a generic "invalid `<clause>` format" error.
    if (!enumValue.has_value())
      return failure();
    *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
    // NOTE(review): looks purely defensive; presumably never taken.
    if (!*modifier)
      return failure();
  }

  // Entry list. When byref tracking is requested, a flag is recorded for
  // every entry, whether or not the `byref` keyword is present. When map
  // indices are requested, entries without `[map_idx=N]` record -1.
  if (parser.parseCommaSeparatedList([&]() {
        if (byref)
          isByRefVec.push_back(
              parser.parseOptionalKeyword("byref").succeeded());

        if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
          return failure();

        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseArgument(regionPrivateArgs.emplace_back()))
          return failure();

        if (mapIndices) {
          if (parser.parseOptionalLSquare().succeeded()) {
            if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
                parser.parseInteger(mapIndicesVec.emplace_back()) ||
                parser.parseRSquare())
              return failure();
          } else {
            mapIndicesVec.push_back(-1);
          }
        }

        return success();
      }))
    return failure();

  // Types are given as a separate comma-separated list after `:`.
  if (parser.parseColon())
    return failure();

  if (parser.parseCommaSeparatedList([&]() {
        if (parser.parseType(types.emplace_back()))
          return failure();

        return success();
      }))
    return failure();

  // Each operand must have exactly one type. The caller emits the
  // diagnostic on failure.
  if (operands.size() != types.size())
    return failure();

  if (parser.parseRParen())
    return failure();

  // Optional marker after `)` that sets the needsBarrier flag.
  if (needsBarrier) {
            .succeeded())
      *needsBarrier = mlir::UnitAttr::get(parser.getContext());
  }

  // Assign the parsed types to the region arguments bound by this clause.
  auto *argsBegin = regionPrivateArgs.begin();
  MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
                               argsBegin + regionArgOffset + types.size());
  for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
    prv.type = type;
  }

  if (symbols) {
    SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
    *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
  }

  if (!mapIndicesVec.empty())
    *mapIndices =
        mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);

  if (byref)
    *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);

  return success();
}
1034
/// Parses an optional map-like clause introduced by `keyword`. Its entries
/// are plain `operand -> argument` bindings followed by their types.
///
/// Succeeds without consuming input when the keyword is absent. Fails when
/// the keyword is present but `mapArgs` is disengaged (the caller provided
/// no storage for this clause), or when the clause body is malformed. Parsed
/// region arguments are appended to `entryBlockArgs`.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    StringRef keyword, std::optional<MapParseArgs> mapArgs) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    if (!mapArgs)
      return failure();

    // Map-like clauses carry no symbols, map indices, byref flags, modifier
    // or barrier marker.
    if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
                                         entryBlockArgs)))
      return failure();
  }
  return success();
}
1049
/// Parses an optional privatization clause introduced by `keyword`.
///
/// Entries carry symbol attributes and, when `privateArgs->mapIndices` is
/// non-null, optional `[map_idx=N]` annotations. A trailing barrier marker
/// may follow the closing `)`. Succeeds without consuming input when the
/// keyword is absent. Fails when the keyword is present but `privateArgs` is
/// disengaged, or when the clause body is malformed.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    if (!privateArgs)
      return failure();

    if (failed(parseClauseWithRegionArgs(
            parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
            &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
            /*modifier=*/nullptr, &privateArgs->needsBarrier)))
      return failure();
  }
  return success();
}
1066
/// Parses an optional reduction-like clause introduced by `keyword`.
///
/// Entries carry symbol attributes and optional `byref` flags. When
/// `reductionArgs->modifier` is non-null, a leading `mod: <keyword>,` is also
/// accepted. Succeeds without consuming input when the keyword is absent.
/// Fails when the keyword is present but `reductionArgs` is disengaged, or
/// when the clause body is malformed.
static ParseResult parseBlockArgClause(
    OpAsmParser &parser,
    StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    if (!reductionArgs)
      return failure();
    if (failed(parseClauseWithRegionArgs(
            parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
            &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
            reductionArgs->modifier)))
      return failure();
  }
  return success();
}
1082
1083static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1084 AllRegionParseArgs args) {
1086
1087 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1088 args.hasDeviceAddrArgs)))
1089 return parser.emitError(parser.getCurrentLocation())
1090 << "invalid `has_device_addr` format";
1091
1092 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1093 args.hostEvalArgs)))
1094 return parser.emitError(parser.getCurrentLocation())
1095 << "invalid `host_eval` format";
1096
1097 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1098 args.inReductionArgs)))
1099 return parser.emitError(parser.getCurrentLocation())
1100 << "invalid `in_reduction` format";
1101
1102 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1103 args.mapArgs)))
1104 return parser.emitError(parser.getCurrentLocation())
1105 << "invalid `map_entries` format";
1106
1107 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1108 args.privateArgs)))
1109 return parser.emitError(parser.getCurrentLocation())
1110 << "invalid `private` format";
1111
1112 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1113 args.reductionArgs)))
1114 return parser.emitError(parser.getCurrentLocation())
1115 << "invalid `reduction` format";
1116
1117 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1118 args.taskReductionArgs)))
1119 return parser.emitError(parser.getCurrentLocation())
1120 << "invalid `task_reduction` format";
1121
1122 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1123 args.useDeviceAddrArgs)))
1124 return parser.emitError(parser.getCurrentLocation())
1125 << "invalid `use_device_addr` format";
1126
1127 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1128 args.useDevicePtrArgs)))
1129 return parser.emitError(parser.getCurrentLocation())
1130 << "invalid `use_device_addr` format";
1131
1132 return parser.parseRegion(region, entryBlockArgs);
1133}
1134
1135// These parseXyz functions correspond to the custom<Xyz> definitions
1136// in the .td file(s).
static ParseResult parseTargetOpRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &hasDeviceAddrTypes,
    SmallVectorImpl<Type> &hostEvalTypes,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    SmallVectorImpl<Type> &mapTypes,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
  // Enables has_device_addr, host_eval, in_reduction, map_entries and
  // private. This is the only parser here that lets private entries carry
  // `[map_idx=N]` annotations, stored in privateMaps.
  AllRegionParseArgs args;
  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.mapArgs.emplace(mapVars, mapTypes);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier, &privateMaps);
  return parseBlockArgRegion(parser, region, args);
}
1161
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier) {
  // Enables only the in_reduction and private clauses. Private entries take
  // no map indices here.
  AllRegionParseArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier);
  return parseBlockArgRegion(parser, region, args);
}
1177
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &inReductionTypes,
    DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
    SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
    ArrayAttr &reductionSyms) {
  // Enables in_reduction, private and reduction. Only the reduction clause
  // accepts a leading `mod:` modifier (stored in reductionMod).
  AllRegionParseArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms, &reductionMod);
  return parseBlockArgRegion(parser, region, args);
}
1198
static ParseResult parsePrivateRegion(
    OpAsmParser &parser, Region &region,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier) {
  // Enables only the private clause (no map indices).
  AllRegionParseArgs args;
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier);
  return parseBlockArgRegion(parser, region, args);
}
1209
    OpAsmParser &parser, Region &region,
    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
    UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
    SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
    ArrayAttr &reductionSyms) {
  // Enables private and reduction. The reduction clause accepts a leading
  // `mod:` modifier (stored in reductionMod).
  AllRegionParseArgs args;
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms, &reductionMod);
  return parseBlockArgRegion(parser, region, args);
}
1225
static ParseResult parseTaskReductionRegion(
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &taskReductionTypes,
    DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
  // Enables only the task_reduction clause (no modifier).
  AllRegionParseArgs args;
  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
                                 taskReductionByref, taskReductionSyms);
  return parseBlockArgRegion(parser, region, args);
}
1236
    OpAsmParser &parser, Region &region,
    SmallVectorImpl<Type> &useDeviceAddrTypes,
    SmallVectorImpl<Type> &useDevicePtrTypes) {
  // Enables use_device_addr and use_device_ptr. Both are plain map-like
  // clauses.
  AllRegionParseArgs args;
  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
  return parseBlockArgRegion(parser, region, args);
}
1248
1249//===----------------------------------------------------------------------===//
1250// Printers for operations including clauses that define entry block arguments.
1251//===----------------------------------------------------------------------===//
1252
1253namespace {
1254struct MapPrintArgs {
1255 ValueRange vars;
1256 TypeRange types;
1257 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1258};
1259struct PrivatePrintArgs {
1260 ValueRange vars;
1261 TypeRange types;
1262 ArrayAttr syms;
1263 UnitAttr needsBarrier;
1264 DenseI64ArrayAttr mapIndices;
1265 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1266 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1267 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1268 mapIndices(mapIndices) {}
1269};
1270struct ReductionPrintArgs {
1271 ValueRange vars;
1272 TypeRange types;
1273 DenseBoolArrayAttr byref;
1274 ArrayAttr syms;
1275 ReductionModifierAttr modifier;
1276 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1277 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1278 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1279};
1280struct AllRegionPrintArgs {
1281 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1282 std::optional<MapPrintArgs> hostEvalArgs;
1283 std::optional<ReductionPrintArgs> inReductionArgs;
1284 std::optional<MapPrintArgs> mapArgs;
1285 std::optional<PrivatePrintArgs> privateArgs;
1286 std::optional<ReductionPrintArgs> reductionArgs;
1287 std::optional<ReductionPrintArgs> taskReductionArgs;
1288 std::optional<MapPrintArgs> useDeviceAddrArgs;
1289 std::optional<MapPrintArgs> useDevicePtrArgs;
1290};
1291} // namespace
1292
    OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
    ValueRange argsSubrange, ValueRange operands, TypeRange types,
    ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
    DenseBoolArrayAttr byref = nullptr,
    ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
  // Prints `clauseName(...)` in the form accepted by
  // parseClauseWithRegionArgs. Nothing is printed when the clause binds no
  // block arguments.
  if (argsSubrange.empty())
    return;

  p << clauseName << "(";

  if (modifier)
    p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";

  // Replace absent optional attributes with neutral placeholders: null
  // symbols, -1 map indices and `false` byref flags. Every entry can then be
  // printed by a single zip. These placeholders are never printed.
  if (!symbols) {
    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
    symbols = ArrayAttr::get(ctx, values);
  }

  if (!mapIndices) {
    llvm::SmallVector<int64_t> values(operands.size(), -1);
    mapIndices = DenseI64ArrayAttr::get(ctx, values);
  }

  if (!byref) {
    mlir::SmallVector<bool> values(operands.size(), false);
    byref = DenseBoolArrayAttr::get(ctx, values);
  }

  // Each entry: [byref] [symbol] operand -> argument [ [map_idx=N] ]
  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
                                        mapIndices.asArrayRef(),
                                        byref.asArrayRef()),
                        p, [&p](auto t) {
                          auto [op, arg, sym, map, isByRef] = t;
                          if (isByRef)
                            p << "byref ";
                          if (sym)
                            p << sym << " ";

                          p << op << " -> " << arg;

                          if (map != -1)
                            p << " [map_idx=" << map << "]";
                        });
  p << " : ";
  llvm::interleaveComma(types, p);
  p << ") ";

  if (needsBarrier)
    p << getPrivateNeedsBarrierSpelling() << " ";
}
1344
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<MapPrintArgs> mapArgs) {
  // Map-like clause: only operands, block arguments and types are printed.
  // A disengaged mapArgs prints nothing.
  if (mapArgs)
    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
                              mapArgs->types);
}
1352
                                StringRef clauseName, ValueRange argsSubrange,
                                std::optional<PrivatePrintArgs> privateArgs) {
  // Privatization clause: symbols, optional map indices and the barrier
  // marker are forwarded. Private clauses have no byref flags or modifier.
  if (privateArgs)
        p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
        privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
        /*modifier=*/nullptr, privateArgs->needsBarrier);
}
1362
1363static void
1364printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1365 ValueRange argsSubrange,
1366 std::optional<ReductionPrintArgs> reductionArgs) {
1367 if (reductionArgs)
1368 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1369 reductionArgs->vars, reductionArgs->types,
1370 reductionArgs->syms, /*mapIndices=*/nullptr,
1371 reductionArgs->byref, reductionArgs->modifier);
1372}
1373
                                const AllRegionPrintArgs &args) {
  // The op's BlockArgOpenMPOpInterface provides which entry block arguments
  // belong to each clause.
  auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
  MLIRContext *ctx = op->getContext();

  // Clauses are printed in the same fixed order the region parser accepts.
  printBlockArgClause(p, ctx, "has_device_addr",
                      iface.getHasDeviceAddrBlockArgs(),
                      args.hasDeviceAddrArgs);
  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
                      args.hostEvalArgs);
  printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
                      args.inReductionArgs);
  printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
                      args.mapArgs);
  printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
                      args.privateArgs);
  printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
                      args.reductionArgs);
  printBlockArgClause(p, ctx, "task_reduction",
                      iface.getTaskReductionBlockArgs(),
                      args.taskReductionArgs);
  printBlockArgClause(p, ctx, "use_device_addr",
                      iface.getUseDeviceAddrBlockArgs(),
                      args.useDeviceAddrArgs);
  printBlockArgClause(p, ctx, "use_device_ptr",
                      iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);

  // The clauses above already print each `operand -> arg` binding, so the
  // entry block arguments are not printed again.
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
1403
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
    OpAsmPrinter &p, Operation *op, Region &region,
    ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
    ValueRange hostEvalVars, TypeRange hostEvalTypes,
    ValueRange inReductionVars, TypeRange inReductionTypes,
    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
    DenseI64ArrayAttr privateMaps) {
  // Prints has_device_addr, host_eval, in_reduction, map_entries and
  // private. Private entries include their `[map_idx=N]` annotations from
  // privateMaps.
  AllRegionPrintArgs args;
  args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.mapArgs.emplace(mapVars, mapTypes);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier, privateMaps);
  printBlockArgRegion(p, op, region, args);
}
1425
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
    ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
  // Prints in_reduction and private. Private entries carry no map indices.
  AllRegionPrintArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier,
                           /*mapIndices=*/nullptr);
  printBlockArgRegion(p, op, region, args);
}
1439
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
    ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
    ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
    ReductionModifierAttr reductionMod, ValueRange reductionVars,
    TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
    ArrayAttr reductionSyms) {
  // Prints in_reduction, private and reduction. Only reduction may print a
  // `mod:` prefix (from reductionMod).
  AllRegionPrintArgs args;
  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                               inReductionByref, inReductionSyms);
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier,
                           /*mapIndices=*/nullptr);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms, reductionMod);
  printBlockArgRegion(p, op, region, args);
}
1458
                               ValueRange privateVars, TypeRange privateTypes,
                               ArrayAttr privateSyms,
                               UnitAttr privateNeedsBarrier) {
  // Prints only the private clause (no map indices).
  AllRegionPrintArgs args;
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier,
                           /*mapIndices=*/nullptr);
  printBlockArgRegion(p, op, region, args);
}
1469
    OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
    TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
    ReductionModifierAttr reductionMod, ValueRange reductionVars,
    TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
    ArrayAttr reductionSyms) {
  // Prints private and reduction. The reduction clause may print a `mod:`
  // prefix (from reductionMod).
  AllRegionPrintArgs args;
  args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                           privateNeedsBarrier,
                           /*mapIndices=*/nullptr);
  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                             reductionSyms, reductionMod);
  printBlockArgRegion(p, op, region, args);
}
1484
                                     Region &region,
                                     ValueRange taskReductionVars,
                                     TypeRange taskReductionTypes,
                                     DenseBoolArrayAttr taskReductionByref,
                                     ArrayAttr taskReductionSyms) {
  // Prints only the task_reduction clause (no modifier).
  AllRegionPrintArgs args;
  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
                                 taskReductionByref, taskReductionSyms);
  printBlockArgRegion(p, op, region, args);
}
1496
                                                 Region &region,
                                                 ValueRange useDeviceAddrVars,
                                                 TypeRange useDeviceAddrTypes,
                                                 ValueRange useDevicePtrVars,
                                                 TypeRange useDevicePtrTypes) {
  // Prints the use_device_addr and use_device_ptr map-like clauses.
  AllRegionPrintArgs args;
  args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
  args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
  printBlockArgRegion(p, op, region, args);
}
1508
/// Parses a comma-separated list of `<prefix> operand : type` entries. The
/// entries are split into two output lists by type: entries whose type is an
/// `omp::IteratedType` go to the iterated lists, all others to the plain
/// lists.
///
/// `parsePrefix` is called at the start of every entry, before the operand.
/// A failure from it aborts the whole list.
template <typename ParsePrefixFn>
static ParseResult parseSplitIteratedList(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &iteratedTypes,
    SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {

  return parser.parseCommaSeparatedList([&]() -> ParseResult {
    if (failed(parsePrefix()))
      return failure();

    Type ty;
    if (parser.parseOperand(v) || parser.parseColonType(ty))
      return failure();

    // Route the entry by its type. The relative order within each output
    // list matches the input order.
    if (llvm::isa<mlir::omp::IteratedType>(ty)) {
      iteratedVars.push_back(v);
      iteratedTypes.push_back(ty);
    } else {
      plainVars.push_back(v);
      plainTypes.push_back(ty);
    }
    return success();
  });
}
1536
/// Prints two lists of `<prefix> operand : type` entries as one
/// comma-separated list. All iterated entries come first, then all plain
/// entries. The original interleaving of a parsed list is therefore not
/// reproduced; only the order within each group is.
///
/// NOTE(review): both callbacks share the single template parameter
/// PrintPrefixFn, so they must be of the same type. Two distinct lambdas
/// would not deduce. Confirm this is intended.
template <typename PrintPrefixFn>
                                   TypeRange iteratedTypes,
                                   ValueRange plainVars, TypeRange plainTypes,
                                   PrintPrefixFn &&printPrefixForPlain,
                                   PrintPrefixFn &&printPrefixForIterated) {

  // Tracks whether a separator is needed before the next entry.
  bool first = true;
  auto emit = [&](Value v, Type t, auto &&printPrefix) {
    if (!first)
      p << ", ";
    printPrefix(v, t);
    p << v << " : " << t;
    first = false;
  };

  for (unsigned i = 0; i < iteratedVars.size(); ++i)
    emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
  for (unsigned i = 0; i < plainVars.size(); ++i)
    emit(plainVars[i], plainTypes[i], printPrefixForPlain);
}
1558
/// Verifies Reduction Clause
///
/// Checks that:
/// - there is one symbol reference (and, if present, one byref flag) per
///   reduction variable;
/// - no accumulator variable appears twice;
/// - every symbol resolves to a reduction declaration whose accumulator type,
///   when set, equals the variable's type.
///
/// With no reduction variables, any engaged `reductionSyms` is rejected,
/// even an empty one. Compare verifyDependVarList, which tolerates empty
/// kind lists.
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
                       OperandRange reductionVars,
                       std::optional<ArrayRef<bool>> reductionByref) {
  if (!reductionVars.empty()) {
    if (!reductionSyms || reductionSyms->size() != reductionVars.size())
      return op->emitOpError()
             << "expected as many reduction symbol references "
                "as reduction variables";
    // NOTE(review): uses emitError rather than emitOpError like the
    // surrounding checks, so this message lacks the op-name prefix.
    if (reductionByref && reductionByref->size() != reductionVars.size())
      return op->emitError() << "expected as many reduction variable by "
                                "reference attributes as reduction variables";
  } else {
    if (reductionSyms)
      return op->emitOpError() << "unexpected reduction symbol references";
    return success();
  }

  // TODO: The following checks should be done in
  // SymbolUserOpInterface::verifySymbolUses.
  DenseSet<Value> accumulators;
  for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
    Value accum = std::get<0>(args);

    if (!accumulators.insert(accum).second)
      return op->emitOpError() << "accumulator variable used more than once";

    Type varType = accum.getType();
    auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
    auto decl =
    if (!decl)
      return op->emitOpError() << "expected symbol reference " << symbolRef
                               << " to point to a reduction declaration";

    // A declaration without an accumulator type accepts any variable type.
    if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
      return op->emitOpError()
             << "expected accumulator (" << varType
             << ") to be the same type as reduction declaration ("
             << decl.getAccumulatorType() << ")";
  }

  return success();
}
1604
1605//===----------------------------------------------------------------------===//
1606// Parser, printer and verifier for Copyprivate
1607//===----------------------------------------------------------------------===//
1608
/// copyprivate-entry-list ::= copyprivate-entry
///                          | copyprivate-entry-list `,` copyprivate-entry
/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
///
/// Appends each operand and type to `copyprivateVars` and
/// `copyprivateTypes`. The per-entry attributes are collected into
/// `copyprivateSyms`.
static ParseResult parseCopyprivate(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(copyprivateVars.emplace_back()) ||
            parser.parseArrow() ||
            parser.parseAttribute(symsVec.emplace_back()) ||
            parser.parseColonType(copyprivateTypes.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
  copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
  return success();
}
1630
/// Print Copyprivate clause
///
/// Emits `var -> sym : type` entries separated by commas, the inverse of
/// parseCopyprivate. Prints nothing when `copyprivateSyms` is absent.
                             OperandRange copyprivateVars,
                             TypeRange copyprivateTypes,
                             std::optional<ArrayAttr> copyprivateSyms) {
  if (!copyprivateSyms.has_value())
    return;
  // llvm::zip stops at the shortest range. Matching sizes are enforced by
  // verifyCopyprivateVarList.
  llvm::interleaveComma(
      llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
      [&](const auto &args) {
        p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
          << std::get<2>(args);
      });
}
1645
/// Verifies CopyPrivate Clause
///
/// Requires one copy-function symbol per copyprivate variable. Each symbol
/// must resolve to a `func.func` or `llvm.func` that takes exactly two
/// arguments. Both arguments must have the same type, and that type must
/// equal the variable's type.
static LogicalResult
                         std::optional<ArrayAttr> copyprivateSyms) {
  // An absent symbol list counts as zero symbols.
  size_t copyprivateSymsSize =
      copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
  if (copyprivateSymsSize != copyprivateVars.size())
    return op->emitOpError() << "inconsistent number of copyprivate vars (= "
                             << copyprivateVars.size()
                             << ") and functions (= " << copyprivateSymsSize
                             << "), both must be equal";
  if (!copyprivateSyms.has_value())
    return success();

  for (auto copyprivateVarAndSym :
       llvm::zip(copyprivateVars, *copyprivateSyms)) {
    auto symbolRef =
        llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
    // Resolve the symbol, trying func::FuncOp first and then
    // LLVM::LLVMFuncOp. funcOp stays empty if neither matches.
    std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
        funcOp;
    if (mlir::func::FuncOp mlirFuncOp =
            symbolRef))
      funcOp = mlirFuncOp;
    else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
                 op, symbolRef))
      funcOp = llvmFuncOp;

    // Uniform accessors over either function kind. They dereference funcOp,
    // so they are only called after the `!funcOp` check below.
    auto getNumArguments = [&] {
      return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
    };

    auto getArgumentType = [&](unsigned i) {
      return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
                        *funcOp);
    };

    if (!funcOp)
      return op->emitOpError() << "expected symbol reference " << symbolRef
                               << " to point to a copy function";

    if (getNumArguments() != 2)
      return op->emitOpError()
             << "expected copy function " << symbolRef << " to have 2 operands";

    Type argTy = getArgumentType(0);
    if (argTy != getArgumentType(1))
      return op->emitOpError() << "expected copy function " << symbolRef
                               << " arguments to have the same type";

    Type varType = std::get<0>(copyprivateVarAndSym).getType();
    if (argTy != varType)
      return op->emitOpError()
             << "expected copy function arguments' type (" << argTy
             << ") to be the same as copyprivate variable's type (" << varType
             << ")";
  }

  return success();
}
1707
1708//===----------------------------------------------------------------------===//
1709// Parser, printer and verifier for DependVarList
1710//===----------------------------------------------------------------------===//
1711
/// depend-entry-list ::= depend-entry
///                     | depend-entry-list `,` depend-entry
/// depend-entry ::= depend-kind `->` ssa-id `:` type
///                | depend-kind `->` ssa-id `:` iterated-type
///
/// Entries whose type is an `omp::IteratedType` are collected into the
/// iterated outputs; all others go to the plain depend outputs. Both kind
/// attributes (`dependKinds`, `iteratedKinds`) are always set, possibly to
/// empty arrays.
static ParseResult parseDependVarList(
    OpAsmParser &parser,
    SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
    SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        StringRef keyword;
        OpAsmParser::UnresolvedOperand operand;
        Type ty;
        if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
            parser.parseOperand(operand) || parser.parseColonType(ty))
          return failure();
        std::optional<ClauseTaskDepend> keywordDepend =
            symbolizeClauseTaskDepend(keyword);
        // NOTE(review): an unknown depend kind fails without a specific
        // diagnostic.
        if (!keywordDepend)
          return failure();
        auto kindAttr =
            ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
        if (llvm::isa<mlir::omp::IteratedType>(ty)) {
          iteratedVars.push_back(operand);
          iteratedTypes.push_back(ty);
          iterKindsVec.push_back(kindAttr);
        } else {
          dependVars.push_back(operand);
          dependTypes.push_back(ty);
          kindsVec.push_back(kindAttr);
        }
        return success();
      })))
    return failure();
  SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
  dependKinds = ArrayAttr::get(parser.getContext(), kinds);
  SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
  iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
  return success();
}
1755
/// Print Depend clause
///
/// Prints `kind -> var : type` entries as one comma-separated list: all
/// plain depend entries first, then all iterated ones. The interleaving of a
/// parsed list is not preserved.
                               OperandRange dependVars, TypeRange dependTypes,
                               std::optional<ArrayAttr> dependKinds,
                               OperandRange iteratedVars,
                               TypeRange iteratedTypes,
                               std::optional<ArrayAttr> iteratedKinds) {
  bool first = true;
  // `kinds` is dereferenced whenever `vars` is non-empty. This relies on the
  // count invariant checked by verifyDependVarList.
  auto printEntries = [&](OperandRange vars, TypeRange types,
                          std::optional<ArrayAttr> kinds) {
    for (unsigned i = 0, e = vars.size(); i < e; ++i) {
      if (!first)
        p << ", ";
      p << stringifyClauseTaskDepend(
               llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
                   .getValue())
        << " -> " << vars[i] << " : " << types[i];
      first = false;
    }
  };
  printEntries(dependVars, dependTypes, dependKinds);
  printEntries(iteratedVars, iteratedTypes, iteratedKinds);
}
1779
1780/// Verifies Depend clause
1781static LogicalResult verifyDependVarList(Operation *op,
1782 std::optional<ArrayAttr> dependKinds,
1783 OperandRange dependVars,
1784 std::optional<ArrayAttr> iteratedKinds,
1785 OperandRange iteratedVars) {
1786 if (!dependVars.empty()) {
1787 if (!dependKinds || dependKinds->size() != dependVars.size())
1788 return op->emitOpError() << "expected as many depend values"
1789 " as depend variables";
1790 } else {
1791 if (dependKinds && !dependKinds->empty())
1792 return op->emitOpError() << "unexpected depend values";
1793 }
1794
1795 if (!iteratedVars.empty()) {
1796 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1797 return op->emitOpError() << "expected as many depend iterated values"
1798 " as depend iterated variables";
1799 } else {
1800 if (iteratedKinds && !iteratedKinds->empty())
1801 return op->emitOpError() << "unexpected depend iterated values";
1802 }
1803
1804 return success();
1805}
1806
1807//===----------------------------------------------------------------------===//
1808// Parser, printer and verifier for Synchronization Hint (2.17.12)
1809//===----------------------------------------------------------------------===//
1810
1811/// Parses a Synchronization Hint clause. The value of hint is an integer
1812/// which is a combination of different hints from `omp_sync_hint_t`.
1813///
1814/// hint-clause = `hint` `(` hint-value `)`
1815static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1816 IntegerAttr &hintAttr) {
1817 StringRef hintKeyword;
1818 int64_t hint = 0;
1819 if (succeeded(parser.parseOptionalKeyword("none"))) {
1820 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1821 return success();
1822 }
1823 auto parseKeyword = [&]() -> ParseResult {
1824 if (failed(parser.parseKeyword(&hintKeyword)))
1825 return failure();
1826 if (hintKeyword == "uncontended")
1827 hint |= 1;
1828 else if (hintKeyword == "contended")
1829 hint |= 2;
1830 else if (hintKeyword == "nonspeculative")
1831 hint |= 4;
1832 else if (hintKeyword == "speculative")
1833 hint |= 8;
1834 else
1835 return parser.emitError(parser.getCurrentLocation())
1836 << hintKeyword << " is not a valid hint";
1837 return success();
1838 };
1839 if (parser.parseCommaSeparatedList(parseKeyword))
1840 return failure();
1841 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1842 return success();
1843}
1844
1845/// Prints a Synchronization Hint clause
1847 IntegerAttr hintAttr) {
1848 int64_t hint = hintAttr.getInt();
1849
1850 if (hint == 0) {
1851 p << "none";
1852 return;
1853 }
1854
1855 // Helper function to get n-th bit from the right end of `value`
1856 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1857
1858 bool uncontended = bitn(hint, 0);
1859 bool contended = bitn(hint, 1);
1860 bool nonspeculative = bitn(hint, 2);
1861 bool speculative = bitn(hint, 3);
1862
1864 if (uncontended)
1865 hints.push_back("uncontended");
1866 if (contended)
1867 hints.push_back("contended");
1868 if (nonspeculative)
1869 hints.push_back("nonspeculative");
1870 if (speculative)
1871 hints.push_back("speculative");
1872
1873 llvm::interleaveComma(hints, p);
1874}
1875
1876/// Verifies a synchronization hint clause
1877static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1878
1879 // Helper function to get n-th bit from the right end of `value`
1880 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1881
1882 bool uncontended = bitn(hint, 0);
1883 bool contended = bitn(hint, 1);
1884 bool nonspeculative = bitn(hint, 2);
1885 bool speculative = bitn(hint, 3);
1886
1887 if (uncontended && contended)
1888 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1889 "omp_sync_hint_contended cannot be combined";
1890 if (nonspeculative && speculative)
1891 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1892 "omp_sync_hint_speculative cannot be combined.";
1893 return success();
1894}
1895
1896//===----------------------------------------------------------------------===//
1897// Parser, printer and verifier for Target
1898//===----------------------------------------------------------------------===//
1899
1900// Helper function to get bitwise AND of `value` and 'flag' then return it as a
1901// boolean
1902static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1903 return (value & flag) == flag;
1904}
1905
1906/// Parses a map_entries map type from a string format back into its numeric
1907/// value.
1908///
1909/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1910/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1911static ParseResult parseMapClause(OpAsmParser &parser,
1912 ClauseMapFlagsAttr &mapType) {
1913 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1914 // This simply verifies the correct keyword is read in, the
1915 // keyword itself is stored inside of the operation
1916 auto parseTypeAndMod = [&]() -> ParseResult {
1917 StringRef mapTypeMod;
1918 if (parser.parseKeyword(&mapTypeMod))
1919 return failure();
1920
1921 if (mapTypeMod == "always")
1922 mapTypeBits |= ClauseMapFlags::always;
1923
1924 if (mapTypeMod == "implicit")
1925 mapTypeBits |= ClauseMapFlags::implicit;
1926
1927 if (mapTypeMod == "ompx_hold")
1928 mapTypeBits |= ClauseMapFlags::ompx_hold;
1929
1930 if (mapTypeMod == "close")
1931 mapTypeBits |= ClauseMapFlags::close;
1932
1933 if (mapTypeMod == "present")
1934 mapTypeBits |= ClauseMapFlags::present;
1935
1936 if (mapTypeMod == "to")
1937 mapTypeBits |= ClauseMapFlags::to;
1938
1939 if (mapTypeMod == "from")
1940 mapTypeBits |= ClauseMapFlags::from;
1941
1942 if (mapTypeMod == "tofrom")
1943 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1944
1945 if (mapTypeMod == "delete")
1946 mapTypeBits |= ClauseMapFlags::del;
1947
1948 if (mapTypeMod == "storage")
1949 mapTypeBits |= ClauseMapFlags::storage;
1950
1951 if (mapTypeMod == "return_param")
1952 mapTypeBits |= ClauseMapFlags::return_param;
1953
1954 if (mapTypeMod == "private")
1955 mapTypeBits |= ClauseMapFlags::priv;
1956
1957 if (mapTypeMod == "literal")
1958 mapTypeBits |= ClauseMapFlags::literal;
1959
1960 if (mapTypeMod == "attach")
1961 mapTypeBits |= ClauseMapFlags::attach;
1962
1963 if (mapTypeMod == "attach_always")
1964 mapTypeBits |= ClauseMapFlags::attach_always;
1965
1966 if (mapTypeMod == "attach_never")
1967 mapTypeBits |= ClauseMapFlags::attach_never;
1968
1969 if (mapTypeMod == "attach_auto")
1970 mapTypeBits |= ClauseMapFlags::attach_auto;
1971
1972 if (mapTypeMod == "ref_ptr")
1973 mapTypeBits |= ClauseMapFlags::ref_ptr;
1974
1975 if (mapTypeMod == "ref_ptee")
1976 mapTypeBits |= ClauseMapFlags::ref_ptee;
1977
1978 if (mapTypeMod == "ref_ptr_ptee")
1979 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1980
1981 if (mapTypeMod == "is_device_ptr")
1982 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1983
1984 return success();
1985 };
1986
1987 if (parser.parseCommaSeparatedList(parseTypeAndMod))
1988 return failure();
1989
1990 mapType =
1991 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1992
1993 return success();
1994}
1995
1996/// Prints a map_entries map type from its numeric value out into its string
1997/// format.
1998static void printMapClause(OpAsmPrinter &p, Operation *op,
1999 ClauseMapFlagsAttr mapType) {
2001 ClauseMapFlags mapFlags = mapType.getValue();
2002
2003 // handling of always, close, present placed at the beginning of the string
2004 // to aid readability
2005 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
2006 mapTypeStrs.push_back("always");
2007 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
2008 mapTypeStrs.push_back("implicit");
2009 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
2010 mapTypeStrs.push_back("ompx_hold");
2011 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
2012 mapTypeStrs.push_back("close");
2013 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
2014 mapTypeStrs.push_back("present");
2015
2016 // special handling of to/from/tofrom/delete and release/alloc, release +
2017 // alloc are the abscense of one of the other flags, whereas tofrom requires
2018 // both the to and from flag to be set.
2019 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
2020 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
2021
2022 if (to && from)
2023 mapTypeStrs.push_back("tofrom");
2024 else if (from)
2025 mapTypeStrs.push_back("from");
2026 else if (to)
2027 mapTypeStrs.push_back("to");
2028
2029 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
2030 mapTypeStrs.push_back("delete");
2031 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
2032 mapTypeStrs.push_back("return_param");
2033 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
2034 mapTypeStrs.push_back("storage");
2035 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
2036 mapTypeStrs.push_back("private");
2037 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
2038 mapTypeStrs.push_back("literal");
2039 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
2040 mapTypeStrs.push_back("attach");
2041 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
2042 mapTypeStrs.push_back("attach_always");
2043 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
2044 mapTypeStrs.push_back("attach_never");
2045 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
2046 mapTypeStrs.push_back("attach_auto");
2047 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
2048 mapTypeStrs.push_back("ref_ptr");
2049 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
2050 mapTypeStrs.push_back("ref_ptee");
2051 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
2052 mapTypeStrs.push_back("ref_ptr_ptee");
2053 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
2054 mapTypeStrs.push_back("is_device_ptr");
2055 if (mapFlags == ClauseMapFlags::none)
2056 mapTypeStrs.push_back("none");
2057
2058 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2059 p << mapTypeStrs[i];
2060 if (i + 1 < mapTypeStrs.size()) {
2061 p << ", ";
2062 }
2063 }
2064}
2065
2066static ParseResult parseMembersIndex(OpAsmParser &parser,
2067 ArrayAttr &membersIdx) {
2068 SmallVector<Attribute> values, memberIdxs;
2069
2070 auto parseIndices = [&]() -> ParseResult {
2071 int64_t value;
2072 if (parser.parseInteger(value))
2073 return failure();
2074 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
2075 APInt(64, value, /*isSigned=*/false)));
2076 return success();
2077 };
2078
2079 do {
2080 if (failed(parser.parseLSquare()))
2081 return failure();
2082
2083 if (parser.parseCommaSeparatedList(parseIndices))
2084 return failure();
2085
2086 if (failed(parser.parseRSquare()))
2087 return failure();
2088
2089 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
2090 values.clear();
2091 } while (succeeded(parser.parseOptionalComma()));
2092
2093 if (!memberIdxs.empty())
2094 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
2095
2096 return success();
2097}
2098
2099static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
2100 ArrayAttr membersIdx) {
2101 if (!membersIdx)
2102 return;
2103
2104 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
2105 p << "[";
2106 auto memberIdx = cast<ArrayAttr>(v);
2107 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2108 p << cast<IntegerAttr>(v2).getInt();
2109 });
2110 p << "]";
2111 });
2112}
2113
2115 VariableCaptureKindAttr mapCaptureType) {
2116 std::string typeCapStr;
2117 llvm::raw_string_ostream typeCap(typeCapStr);
2118 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2119 typeCap << "ByRef";
2120 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2121 typeCap << "ByCopy";
2122 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2123 typeCap << "VLAType";
2124 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2125 typeCap << "This";
2126 p << typeCapStr;
2127}
2128
2129static ParseResult parseCaptureType(OpAsmParser &parser,
2130 VariableCaptureKindAttr &mapCaptureType) {
2131 StringRef mapCaptureKey;
2132 if (parser.parseKeyword(&mapCaptureKey))
2133 return failure();
2134
2135 if (mapCaptureKey == "This")
2136 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2137 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2138 if (mapCaptureKey == "ByRef")
2139 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2140 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2141 if (mapCaptureKey == "ByCopy")
2142 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2143 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2144 if (mapCaptureKey == "VLAType")
2145 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2146 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2147
2148 return success();
2149}
2150
2151static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
2154
2155 for (auto mapOp : mapVars) {
2156 if (!mapOp.getDefiningOp())
2157 return emitError(op->getLoc(), "missing map operation");
2158
2159 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2160 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2161
2162 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2163 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2164 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2165
2166 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2167 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2168 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2169
2170 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2171 return emitError(op->getLoc(),
2172 "to, from, tofrom and alloc map types are permitted");
2173
2174 if (isa<TargetEnterDataOp>(op) && (from || del))
2175 return emitError(op->getLoc(), "to and alloc map types are permitted");
2176
2177 if (isa<TargetExitDataOp>(op) && to)
2178 return emitError(op->getLoc(),
2179 "from, release and delete map types are permitted");
2180
2181 if (isa<TargetUpdateOp>(op)) {
2182 if (del) {
2183 return emitError(op->getLoc(),
2184 "at least one of to or from map types must be "
2185 "specified, other map types are not permitted");
2186 }
2187
2188 if (!to && !from) {
2189 return emitError(op->getLoc(),
2190 "at least one of to or from map types must be "
2191 "specified, other map types are not permitted");
2192 }
2193
2194 auto updateVar = mapInfoOp.getVarPtr();
2195
2196 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2197 (from && updateToVars.contains(updateVar))) {
2198 return emitError(
2199 op->getLoc(),
2200 "either to or from map types can be specified, not both");
2201 }
2202
2203 if (always || close || implicit) {
2204 return emitError(
2205 op->getLoc(),
2206 "present, mapper and iterator map type modifiers are permitted");
2207 }
2208
2209 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2210 }
2211 } else if (!isa<DeclareMapperInfoOp>(op)) {
2212 return emitError(op->getLoc(),
2213 "map argument is not a map entry operation");
2214 }
2215 }
2216
2217 return success();
2218}
2219
2220static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2221 std::optional<DenseI64ArrayAttr> privateMapIndices =
2222 targetOp.getPrivateMapsAttr();
2223
2224 // None of the private operands are mapped.
2225 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2226 return success();
2227
2228 OperandRange privateVars = targetOp.getPrivateVars();
2229
2230 if (privateMapIndices.value().size() !=
2231 static_cast<int64_t>(privateVars.size()))
2232 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2233 "`private_maps` attribute mismatch");
2234
2235 return success();
2236}
2237
2238//===----------------------------------------------------------------------===//
2239// MapInfoOp
2240//===----------------------------------------------------------------------===//
2241
2242static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2243 StringRef clauseName,
2244 OperandRange vars) {
2245 for (Value var : vars)
2246 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2247 return op->emitOpError()
2248 << "'" << clauseName
2249 << "' arguments must be defined by 'omp.map.info' ops";
2250 return success();
2251}
2252
2253LogicalResult MapInfoOp::verify() {
2254 if (getMapperId() &&
2256 *this, getMapperIdAttr())) {
2257 return emitError("invalid mapper id");
2258 }
2259
2260 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2261 return failure();
2262
2263 return success();
2264}
2265
2266//===----------------------------------------------------------------------===//
2267// TargetDataOp
2268//===----------------------------------------------------------------------===//
2269
2270void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2271 const TargetDataOperands &clauses) {
2272 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2273 clauses.mapVars, clauses.useDeviceAddrVars,
2274 clauses.useDevicePtrVars);
2275}
2276
2277LogicalResult TargetDataOp::verify() {
2278 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2279 getUseDeviceAddrVars().empty()) {
2280 return ::emitError(this->getLoc(),
2281 "At least one of map, use_device_ptr_vars, or "
2282 "use_device_addr_vars operand must be present");
2283 }
2284
2285 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2286 getUseDevicePtrVars())))
2287 return failure();
2288
2289 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2290 getUseDeviceAddrVars())))
2291 return failure();
2292
2293 return verifyMapClause(*this, getMapVars());
2294}
2295
2296//===----------------------------------------------------------------------===//
2297// TargetEnterDataOp
2298//===----------------------------------------------------------------------===//
2299
2300void TargetEnterDataOp::build(
2301 OpBuilder &builder, OperationState &state,
2302 const TargetEnterExitUpdateDataOperands &clauses) {
2303 MLIRContext *ctx = builder.getContext();
2304 TargetEnterDataOp::build(
2305 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2306 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2307 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2308 clauses.nowait);
2309}
2310
2311LogicalResult TargetEnterDataOp::verify() {
2312 LogicalResult verifyDependVars =
2313 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2314 getDependIteratedKinds(), getDependIterated());
2315 return failed(verifyDependVars) ? verifyDependVars
2316 : verifyMapClause(*this, getMapVars());
2317}
2318
2319//===----------------------------------------------------------------------===//
2320// TargetExitDataOp
2321//===----------------------------------------------------------------------===//
2322
2323void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2324 const TargetEnterExitUpdateDataOperands &clauses) {
2325 MLIRContext *ctx = builder.getContext();
2326 TargetExitDataOp::build(
2327 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2328 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2329 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2330 clauses.nowait);
2331}
2332
2333LogicalResult TargetExitDataOp::verify() {
2334 LogicalResult verifyDependVars =
2335 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2336 getDependIteratedKinds(), getDependIterated());
2337 return failed(verifyDependVars) ? verifyDependVars
2338 : verifyMapClause(*this, getMapVars());
2339}
2340
2341//===----------------------------------------------------------------------===//
2342// TargetUpdateOp
2343//===----------------------------------------------------------------------===//
2344
2345void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2346 const TargetEnterExitUpdateDataOperands &clauses) {
2347 MLIRContext *ctx = builder.getContext();
2348 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2349 clauses.dependVars,
2350 makeArrayAttr(ctx, clauses.dependIteratedKinds),
2351 clauses.dependIterated, clauses.device, clauses.ifExpr,
2352 clauses.mapVars, clauses.nowait);
2353}
2354
2355LogicalResult TargetUpdateOp::verify() {
2356 LogicalResult verifyDependVars =
2357 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2358 getDependIteratedKinds(), getDependIterated());
2359 return failed(verifyDependVars) ? verifyDependVars
2360 : verifyMapClause(*this, getMapVars());
2361}
2362
2363//===----------------------------------------------------------------------===//
2364// TargetOp
2365//===----------------------------------------------------------------------===//
2366
2367void TargetOp::build(OpBuilder &builder, OperationState &state,
2368 const TargetOperands &clauses) {
2369 MLIRContext *ctx = builder.getContext();
2370 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2371 // inReductionByref, inReductionSyms.
2372 TargetOp::build(
2373 builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
2374 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2375 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2376 clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
2377 clauses.ifExpr,
2378 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2379 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2380 clauses.nowait, clauses.privateVars,
2381 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2382 clauses.threadLimitVars,
2383 /*private_maps=*/nullptr);
2384}
2385
2386LogicalResult TargetOp::verify() {
2387 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
2388 getDependIteratedKinds(),
2389 getDependIterated())))
2390 return failure();
2391
2392 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2393 getHasDeviceAddrVars())))
2394 return failure();
2395
2396 if (failed(verifyMapClause(*this, getMapVars())))
2397 return failure();
2398
2399 return verifyPrivateVarsMapping(*this);
2400}
2401
2402LogicalResult TargetOp::verifyRegions() {
2403 auto teamsOps = getOps<TeamsOp>();
2404 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2405 return emitError("target containing multiple 'omp.teams' nested ops");
2406
2407 // Check that host_eval values are only used in legal ways.
2408 Operation *capturedOp = getInnermostCapturedOmpOp();
2409 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2410 for (Value hostEvalArg :
2411 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2412 for (Operation *user : hostEvalArg.getUsers()) {
2413 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2414 // Check if used in num_teams_lower or any of num_teams_upper_vars
2415 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2416 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2417 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2418 continue;
2419
2420 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2421 "and 'thread_limit' in 'omp.teams'";
2422 }
2423 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2424 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2425 parallelOp->isAncestor(capturedOp) &&
2426 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2427 continue;
2428
2429 return emitOpError()
2430 << "host_eval argument only legal as 'num_threads' in "
2431 "'omp.parallel' when representing target SPMD";
2432 }
2433 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2434 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2435 loopNestOp.getOperation() == capturedOp &&
2436 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2437 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2438 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2439 continue;
2440
2441 return emitOpError() << "host_eval argument only legal as loop bounds "
2442 "and steps in 'omp.loop_nest' when trip count "
2443 "must be evaluated in the host";
2444 }
2445
2446 return emitOpError() << "host_eval argument illegal use in '"
2447 << user->getName() << "' operation";
2448 }
2449 }
2450 return success();
2451}
2452
2453static Operation *
2454findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2455 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2456 assert(rootOp && "expected valid operation");
2457
2458 Dialect *ompDialect = rootOp->getDialect();
2459 Operation *capturedOp = nullptr;
2460 DominanceInfo domInfo;
2461
2462 // Process in pre-order to check operations from outermost to innermost,
2463 // ensuring we only enter the region of an operation if it meets the criteria
2464 // for being captured. We stop the exploration of nested operations as soon as
2465 // we process a region holding no operations to be captured.
2466 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2467 if (op == rootOp)
2468 return WalkResult::advance();
2469
2470 // Ignore operations of other dialects or omp operations with no regions,
2471 // because these will only be checked if they are siblings of an omp
2472 // operation that can potentially be captured.
2473 bool isOmpDialect = op->getDialect() == ompDialect;
2474 bool hasRegions = op->getNumRegions() > 0;
2475 if (!isOmpDialect || !hasRegions)
2476 return WalkResult::skip();
2477
2478 // This operation cannot be captured if it can be executed more than once
2479 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2480 // be executed before all exits of the region (i.e. it doesn't dominate all
2481 // blocks with no successors reachable from the entry block).
2482 if (checkSingleMandatoryExec) {
2483 Region *parentRegion = op->getParentRegion();
2484 Block *parentBlock = op->getBlock();
2485
2486 for (Block *successor : parentBlock->getSuccessors())
2487 if (successor->isReachable(parentBlock))
2488 return WalkResult::interrupt();
2489
2490 for (Block &block : *parentRegion)
2491 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2492 !domInfo.dominates(parentBlock, &block))
2493 return WalkResult::interrupt();
2494 }
2495
2496 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2497 // into nested operations.
2498 for (Operation &sibling : op->getParentRegion()->getOps())
2499 if (&sibling != op && !siblingAllowedFn(&sibling))
2500 return WalkResult::interrupt();
2501
2502 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2503 // Otherwise, process the contents of this operation.
2504 capturedOp = op;
2505 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2507 });
2508
2509 return capturedOp;
2510}
2511
2512Operation *TargetOp::getInnermostCapturedOmpOp() {
2513 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2514
2515 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2516 // effects, but don't include a memory write effect.
2517 return findCapturedOmpOp(
2518 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2519 if (!sibling)
2520 return false;
2521
2522 if (ompDialect == sibling->getDialect())
2523 return sibling->hasTrait<OpTrait::IsTerminator>();
2524
2525 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2527 effects;
2528 memOp.getEffects(effects);
2529 return !llvm::any_of(
2530 effects, [&](MemoryEffects::EffectInstance &effect) {
2531 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2532 isa<SideEffects::AutomaticAllocationScopeResource>(
2533 effect.getResource());
2534 });
2535 }
2536 return true;
2537 });
2538}
2539
2540/// Check if we can promote SPMD kernel to No-Loop kernel.
2541static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2542 WsloopOp *wsLoopOp) {
2543 // num_teams clause can break no-loop teams/threads assumption.
2544 if (!teamsOp.getNumTeamsUpperVars().empty())
2545 return false;
2546
2547 // Reduction kernels are slower in no-loop mode.
2548 if (teamsOp.getNumReductionVars())
2549 return false;
2550 if (wsLoopOp->getNumReductionVars())
2551 return false;
2552
2553 // Check if the user allows the promotion of kernels to no-loop mode.
2554 OffloadModuleInterface offloadMod =
2555 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2556 if (!offloadMod)
2557 return false;
2558 auto ompFlags = offloadMod.getFlags();
2559 if (!ompFlags)
2560 return false;
2561 return ompFlags.getAssumeTeamsOversubscription() &&
2562 ompFlags.getAssumeThreadsOversubscription();
2563}
2564
2565TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2566 // A non-null captured op is only valid if it resides inside of a TargetOp
2567 // and is the result of calling getInnermostCapturedOmpOp() on it.
2568 TargetOp targetOp =
2569 capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2570 assert((!capturedOp ||
2571 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2572 "unexpected captured op");
2573
2574 // If it's not capturing a loop, it's a default target region.
2575 if (!isa_and_present<LoopNestOp>(capturedOp))
2576 return TargetRegionFlags::generic;
2577
2578 // Get the innermost non-simd loop wrapper.
2580 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2581 assert(!loopWrappers.empty());
2582
2583 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2584 if (isa<SimdOp>(innermostWrapper))
2585 innermostWrapper = std::next(innermostWrapper);
2586
2587 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2588 if (numWrappers != 1 && numWrappers != 2)
2589 return TargetRegionFlags::generic;
2590
2591 // Detect target-teams-distribute-parallel-wsloop[-simd].
2592 if (numWrappers == 2) {
2593 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2594 if (!wsloopOp)
2595 return TargetRegionFlags::generic;
2596
2597 innermostWrapper = std::next(innermostWrapper);
2598 if (!isa<DistributeOp>(innermostWrapper))
2599 return TargetRegionFlags::generic;
2600
2601 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2602 if (!isa_and_present<ParallelOp>(parallelOp))
2603 return TargetRegionFlags::generic;
2604
2605 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
2606 if (!teamsOp)
2607 return TargetRegionFlags::generic;
2608
2609 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2610 TargetRegionFlags result =
2611 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2612 if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2613 result = result | TargetRegionFlags::no_loop;
2614 return result;
2615 }
2616 }
2617 // Detect target-teams-distribute[-simd] and target-teams-loop.
2618 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2619 Operation *teamsOp = (*innermostWrapper)->getParentOp();
2620 if (!isa_and_present<TeamsOp>(teamsOp))
2621 return TargetRegionFlags::generic;
2622
2623 if (teamsOp->getParentOp() != targetOp.getOperation())
2624 return TargetRegionFlags::generic;
2625
2626 if (isa<LoopOp>(innermostWrapper))
2627 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2628
2629 // Find single immediately nested captured omp.parallel and add spmd flag
2630 // (generic-spmd case).
2631 //
2632 // TODO: This shouldn't have to be done here, as it is too easy to break.
2633 // The openmp-opt pass should be updated to be able to promote kernels like
2634 // this from "Generic" to "Generic-SPMD". However, the use of the
2635 // `kmpc_distribute_static_loop` family of functions produced by the
2636 // OMPIRBuilder for these kernels prevents that from working.
2637 Dialect *ompDialect = targetOp->getDialect();
2638 Operation *nestedCapture = findCapturedOmpOp(
2639 capturedOp, /*checkSingleMandatoryExec=*/false,
2640 [&](Operation *sibling) {
2641 return sibling && (ompDialect != sibling->getDialect() ||
2642 sibling->hasTrait<OpTrait::IsTerminator>());
2643 });
2644
2645 TargetRegionFlags result =
2646 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2647
2648 if (!nestedCapture)
2649 return result;
2650
2651 while (nestedCapture->getParentOp() != capturedOp)
2652 nestedCapture = nestedCapture->getParentOp();
2653
2654 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2655 : result;
2656 }
2657 // Detect target-parallel-wsloop[-simd].
2658 else if (isa<WsloopOp>(innermostWrapper)) {
2659 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2660 if (!isa_and_present<ParallelOp>(parallelOp))
2661 return TargetRegionFlags::generic;
2662
2663 if (parallelOp->getParentOp() == targetOp.getOperation())
2664 return TargetRegionFlags::spmd;
2665 }
2666
2667 return TargetRegionFlags::generic;
2668}
2669
2670//===----------------------------------------------------------------------===//
2671// ParallelOp
2672//===----------------------------------------------------------------------===//
2673
2674void ParallelOp::build(OpBuilder &builder, OperationState &state,
2675 ArrayRef<NamedAttribute> attributes) {
2676 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2677 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2678 /*num_threads_vars=*/ValueRange(),
2679 /*private_vars=*/ValueRange(),
2680 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2681 /*proc_bind_kind=*/nullptr,
2682 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2683 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2684 state.addAttributes(attributes);
2685}
2686
2687void ParallelOp::build(OpBuilder &builder, OperationState &state,
2688 const ParallelOperands &clauses) {
2689 MLIRContext *ctx = builder.getContext();
2690 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2691 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2692 makeArrayAttr(ctx, clauses.privateSyms),
2693 clauses.privateNeedsBarrier, clauses.procBindKind,
2694 clauses.reductionMod, clauses.reductionVars,
2695 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2696 makeArrayAttr(ctx, clauses.reductionSyms));
2697}
2698
/// Verifies the private clause of `op`. The checks are:
///  - the number of private variables equals the number of privatizer
///    symbols;
///  - every symbol resolves to a `PrivateClauseOp`;
///  - when the privatizer reports an argument type, it equals the type of the
///    corresponding private variable.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privateSyms = op.getPrivateSymsAttr();

  // Nothing to verify when neither variables nor symbols are present.
  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
    return success();

  auto numPrivateVars = privateVars.size();
  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

  // The counts must agree before the two lists are zipped below.
  if (numPrivateVars != numPrivateSyms)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivateSyms;

  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
    Type varType = std::get<0>(privateVarInfo).getType();
    SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
    // NOTE(review): the right-hand side of this initializer is missing from
    // this listing. It is presumably a symbol-table lookup of `privateSym`
    // relative to `op`; confirm against the upstream source.
    PrivateClauseOp privatizerOp =

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privateSym << "'";

    Type privatizerType = privatizerOp.getArgType();

    // A null privatizer type is accepted without comparison.
    if (privatizerType && (varType != privatizerType))
      return op.emitError()
             << "type mismatch between a "
             << (privatizerOp.getDataSharingType() ==
                         DataSharingClauseType::Private
                     ? "private"
                     : "firstprivate")
             << " variable and its privatizer op, var type: " << varType
             << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}
2741
2742LogicalResult ParallelOp::verify() {
2743 if (getAllocateVars().size() != getAllocatorVars().size())
2744 return emitError(
2745 "expected equal sizes for allocate and allocator variables");
2746
2747 if (failed(verifyPrivateVarList(*this)))
2748 return failure();
2749
2750 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2751 getReductionByref());
2752}
2753
2754LogicalResult ParallelOp::verifyRegions() {
2755 auto distChildOps = getOps<DistributeOp>();
2756 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2757 if (numDistChildOps > 1)
2758 return emitError()
2759 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2760
2761 if (numDistChildOps == 1) {
2762 if (!isComposite())
2763 return emitError()
2764 << "'omp.composite' attribute missing from composite operation";
2765
2766 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2767 Operation &distributeOp = **distChildOps.begin();
2768 for (Operation &childOp : getOps()) {
2769 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2770 continue;
2771
2772 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2773 return emitError() << "unexpected OpenMP operation inside of composite "
2774 "'omp.parallel': "
2775 << childOp.getName();
2776 }
2777 } else if (isComposite()) {
2778 return emitError()
2779 << "'omp.composite' attribute present in non-composite operation";
2780 }
2781 return success();
2782}
2783
2784//===----------------------------------------------------------------------===//
2785// TeamsOp
2786//===----------------------------------------------------------------------===//
2787
  // Walk up the chain of enclosing operations, starting at the parent of `op`
  // (this reassigns `op`). Any ancestor from the OpenMP dialect yields false.
  // Running off the top of the chain yields true.
  // NOTE(review): this function's signature line is missing from this
  // listing. Also, `isa<OpenMPDialect>` asserts on a null dialect pointer, so
  // an ancestor without a loaded dialect would trip it; confirm whether that
  // can occur.
  while ((op = op->getParentOp()))
    if (isa<OpenMPDialect>(op->getDialect()))
      return false;
  return true;
}
2794
2795void TeamsOp::build(OpBuilder &builder, OperationState &state,
2796 const TeamsOperands &clauses) {
2797 MLIRContext *ctx = builder.getContext();
2798 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2799 TeamsOp::build(
2800 builder, state, clauses.allocateVars, clauses.allocatorVars,
2801 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2802 /*private_vars=*/{}, /*private_syms=*/nullptr,
2803 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2804 clauses.reductionVars,
2805 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2806 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2807}
2808
2809// Verify num_teams clause
2810static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2811 OperandRange numTeamsUpperVars) {
2812 // If lower is specified, upper must have exactly one value
2813 if (numTeamsLower) {
2814 if (numTeamsUpperVars.size() != 1)
2815 return op->emitError(
2816 "expected exactly one num_teams upper bound when lower bound is "
2817 "specified");
2818 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2819 return op->emitError(
2820 "expected num_teams upper bound and lower bound to be "
2821 "the same type");
2822 }
2823
2824 return success();
2825}
2826
/// Verifies `omp.teams`: its placement, the `num_teams` bounds, allocate and
/// allocator pairing, and the reduction clause.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  // NOTE(review): the right-hand operand of the `&&` below is missing from
  // this listing. Judging by the diagnostic, it presumably checks that `op` is
  // not nested in any OpenMP dialect operation; confirm upstream. Note that
  // `isa<TargetOp>` asserts if `getParentOp()` returns null.
  if (!isa<TargetOp>(op->getParentOp()) &&
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
                                  this->getNumTeamsUpperVars())))
    return failure();

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
2852
2853//===----------------------------------------------------------------------===//
2854// SectionOp
2855//===----------------------------------------------------------------------===//
2856
2857OperandRange SectionOp::getPrivateVars() {
2858 return getParentOp().getPrivateVars();
2859}
2860
2861OperandRange SectionOp::getReductionVars() {
2862 return getParentOp().getReductionVars();
2863}
2864
2865//===----------------------------------------------------------------------===//
2866// SectionsOp
2867//===----------------------------------------------------------------------===//
2868
2869void SectionsOp::build(OpBuilder &builder, OperationState &state,
2870 const SectionsOperands &clauses) {
2871 MLIRContext *ctx = builder.getContext();
2872 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2873 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2874 clauses.nowait, /*private_vars=*/{},
2875 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2876 clauses.reductionMod, clauses.reductionVars,
2877 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2878 makeArrayAttr(ctx, clauses.reductionSyms));
2879}
2880
2881LogicalResult SectionsOp::verify() {
2882 if (getAllocateVars().size() != getAllocatorVars().size())
2883 return emitError(
2884 "expected equal sizes for allocate and allocator variables");
2885
2886 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2887 getReductionByref());
2888}
2889
2890LogicalResult SectionsOp::verifyRegions() {
2891 for (auto &inst : *getRegion().begin()) {
2892 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2893 return emitOpError()
2894 << "expected omp.section op or terminator op inside region";
2895 }
2896 }
2897
2898 return success();
2899}
2900
2901//===----------------------------------------------------------------------===//
2902// SingleOp
2903//===----------------------------------------------------------------------===//
2904
2905void SingleOp::build(OpBuilder &builder, OperationState &state,
2906 const SingleOperands &clauses) {
2907 MLIRContext *ctx = builder.getContext();
2908 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2909 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2910 clauses.copyprivateVars,
2911 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2912 /*private_vars=*/{}, /*private_syms=*/nullptr,
2913 /*private_needs_barrier=*/nullptr);
2914}
2915
2916LogicalResult SingleOp::verify() {
2917 // Check for allocate clause restrictions
2918 if (getAllocateVars().size() != getAllocatorVars().size())
2919 return emitError(
2920 "expected equal sizes for allocate and allocator variables");
2921
2922 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2923 getCopyprivateSyms());
2924}
2925
2926//===----------------------------------------------------------------------===//
2927// WorkshareOp
2928//===----------------------------------------------------------------------===//
2929
2930void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2931 const WorkshareOperands &clauses) {
2932 WorkshareOp::build(builder, state, clauses.nowait);
2933}
2934
2935//===----------------------------------------------------------------------===//
2936// WorkshareLoopWrapperOp
2937//===----------------------------------------------------------------------===//
2938
2939LogicalResult WorkshareLoopWrapperOp::verify() {
2940 if (!(*this)->getParentOfType<WorkshareOp>())
2941 return emitOpError() << "must be nested in an omp.workshare";
2942 return success();
2943}
2944
2945LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2946 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2947 getNestedWrapper())
2948 return emitOpError() << "expected to be a standalone loop wrapper";
2949
2950 return success();
2951}
2952
2953//===----------------------------------------------------------------------===//
2954// LoopWrapperInterface
2955//===----------------------------------------------------------------------===//
2956
/// Structural checks shared by every loop wrapper op. The op must:
///  - carry the required traits;
///  - have exactly one region holding exactly one operation;
///  - have that operation be either another loop wrapper or a `LoopNestOp`.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  // NOTE(review): the second operand of the `||` below is missing from this
  // listing. The diagnostic suggests it checks the `SingleBlock` trait;
  // confirm upstream.
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  // Counts operations across all blocks of the region.
  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  // Safe to dereference: the check above guarantees exactly one op.
  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}
2979
2980//===----------------------------------------------------------------------===//
2981// LoopOp
2982//===----------------------------------------------------------------------===//
2983
2984void LoopOp::build(OpBuilder &builder, OperationState &state,
2985 const LoopOperands &clauses) {
2986 MLIRContext *ctx = builder.getContext();
2987
2988 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2989 makeArrayAttr(ctx, clauses.privateSyms),
2990 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2991 clauses.reductionMod, clauses.reductionVars,
2992 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2993 makeArrayAttr(ctx, clauses.reductionSyms));
2994}
2995
2996LogicalResult LoopOp::verify() {
2997 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2998 getReductionByref());
2999}
3000
3001LogicalResult LoopOp::verifyRegions() {
3002 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3003 getNestedWrapper())
3004 return emitOpError() << "expected to be a standalone loop wrapper";
3005
3006 return success();
3007}
3008
3009//===----------------------------------------------------------------------===//
3010// WsloopOp
3011//===----------------------------------------------------------------------===//
3012
3013void WsloopOp::build(OpBuilder &builder, OperationState &state,
3014 ArrayRef<NamedAttribute> attributes) {
3015 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
3016 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
3017 /*linear_var_types*/ nullptr, /*linear_modifiers=*/nullptr,
3018 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
3019 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
3020 /*private_needs_barrier=*/false,
3021 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
3022 /*reduction_byref=*/nullptr,
3023 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
3024 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
3025 /*schedule_simd=*/false);
3026 state.addAttributes(attributes);
3027}
3028
3029void WsloopOp::build(OpBuilder &builder, OperationState &state,
3030 const WsloopOperands &clauses) {
3031 MLIRContext *ctx = builder.getContext();
3032 // TODO: Store clauses in op: allocateVars, allocatorVars
3033 WsloopOp::build(
3034 builder, state,
3035 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
3036 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3037 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3038 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3039 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3040 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3041 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3042 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3043}
3044
3045LogicalResult WsloopOp::verify() {
3046 if (failed(
3047 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3048 return failure();
3049 if (getLinearVars().size() &&
3050 getLinearVarTypes().value().size() != getLinearVars().size())
3051 return emitError() << "Ill-formed type attributes for linear variables";
3052 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3053 getReductionByref());
3054}
3055
3056LogicalResult WsloopOp::verifyRegions() {
3057 bool isCompositeChildLeaf =
3058 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3059
3060 if (LoopWrapperInterface nested = getNestedWrapper()) {
3061 if (!isComposite())
3062 return emitError()
3063 << "'omp.composite' attribute missing from composite wrapper";
3064
3065 // Check for the allowed leaf constructs that may appear in a composite
3066 // construct directly after DO/FOR.
3067 if (!isa<SimdOp>(nested))
3068 return emitError() << "only supported nested wrapper is 'omp.simd'";
3069
3070 } else if (isComposite() && !isCompositeChildLeaf) {
3071 return emitError()
3072 << "'omp.composite' attribute present in non-composite wrapper";
3073 } else if (!isComposite() && isCompositeChildLeaf) {
3074 return emitError()
3075 << "'omp.composite' attribute missing from composite wrapper";
3076 }
3077
3078 return success();
3079}
3080
3081//===----------------------------------------------------------------------===//
3082// Simd construct [2.9.3.1]
3083//===----------------------------------------------------------------------===//
3084
3085void SimdOp::build(OpBuilder &builder, OperationState &state,
3086 const SimdOperands &clauses) {
3087 MLIRContext *ctx = builder.getContext();
3088 SimdOp::build(builder, state, clauses.alignedVars,
3089 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
3090 clauses.linearVars, clauses.linearStepVars,
3091 clauses.linearVarTypes, clauses.linearModifiers,
3092 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3093 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3094 clauses.privateNeedsBarrier, clauses.reductionMod,
3095 clauses.reductionVars,
3096 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3097 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
3098 clauses.simdlen);
3099}
3100
/// Verifies `omp.simd`. The checks cover:
///  - simdlen versus safelen;
///  - the aligned, nontemporal and linear clauses;
///  - consistency of the composite marker with the op's position;
///  - no firstprivate privatizers;
///  - per-variable linear types.
LogicalResult SimdOp::verify() {
  // When both are present, simdlen must not exceed safelen.
  if (getSimdlen().has_value() && getSafelen().has_value() &&
      getSimdlen().value() > getSafelen().value())
    return emitOpError()
           << "simdlen clause and safelen clause are both present, but the "
              "simdlen value is not less than or equal to safelen value";

  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
    return failure();

  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
    return failure();

  if (failed(
          verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
    return failure();

  // A simd nested directly inside another loop wrapper is the leaf of a
  // composite construct.
  bool isCompositeChildLeaf =
      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());

  if (!isComposite() && isCompositeChildLeaf)
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  if (isComposite() && !isCompositeChildLeaf)
    return emitError()
           << "'omp.composite' attribute present in non-composite wrapper";

  // Firstprivate is not allowed for SIMD in the standard. Check that none of
  // the private decls are for firstprivate.
  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
  if (privateSyms) {
    for (const Attribute &sym : *privateSyms) {
      auto symRef = cast<SymbolRefAttr>(sym);
      // NOTE(review): the first line of this initializer (presumably a
      // symbol-table lookup call taking `getOperation(), symRef`) is missing
      // from this listing; confirm upstream.
      omp::PrivateClauseOp privatizer =
          getOperation(), symRef);
      if (!privatizer)
        return emitError() << "Cannot find privatizer '" << symRef << "'";
      if (privatizer.getDataSharingType() ==
          DataSharingClauseType::FirstPrivate)
        return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
    }
  }

  // NOTE(review): `.value()` is called without checking that the optional
  // `linear_var_types` attribute is present. Linear vars without that
  // attribute crash here instead of producing a diagnostic.
  if (getLinearVars().size() &&
      getLinearVarTypes().value().size() != getLinearVars().size())
    return emitError() << "Ill-formed type attributes for linear variables";
  return success();
}
3151
3152LogicalResult SimdOp::verifyRegions() {
3153 if (getNestedWrapper())
3154 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3155
3156 return success();
3157}
3158
3159//===----------------------------------------------------------------------===//
3160// Distribute construct [2.9.4.1]
3161//===----------------------------------------------------------------------===//
3162
3163void DistributeOp::build(OpBuilder &builder, OperationState &state,
3164 const DistributeOperands &clauses) {
3165 DistributeOp::build(builder, state, clauses.allocateVars,
3166 clauses.allocatorVars, clauses.distScheduleStatic,
3167 clauses.distScheduleChunkSize, clauses.order,
3168 clauses.orderMod, clauses.privateVars,
3169 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3170 clauses.privateNeedsBarrier);
3171}
3172
3173LogicalResult DistributeOp::verify() {
3174 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3175 return emitOpError() << "chunk size set without "
3176 "dist_schedule_static being present";
3177
3178 if (getAllocateVars().size() != getAllocatorVars().size())
3179 return emitError(
3180 "expected equal sizes for allocate and allocator variables");
3181
3182 return success();
3183}
3184
3185LogicalResult DistributeOp::verifyRegions() {
3186 if (LoopWrapperInterface nested = getNestedWrapper()) {
3187 if (!isComposite())
3188 return emitError()
3189 << "'omp.composite' attribute missing from composite wrapper";
3190 // Check for the allowed leaf constructs that may appear in a composite
3191 // construct directly after DISTRIBUTE.
3192 if (isa<WsloopOp>(nested)) {
3193 Operation *parentOp = (*this)->getParentOp();
3194 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3195 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3196 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3197 "when a composite 'omp.parallel' is the direct "
3198 "parent";
3199 }
3200 } else if (!isa<SimdOp>(nested))
3201 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3202 "'omp.wsloop'";
3203 } else if (isComposite()) {
3204 return emitError()
3205 << "'omp.composite' attribute present in non-composite wrapper";
3206 }
3207
3208 return success();
3209}
3210
3211//===----------------------------------------------------------------------===//
3212// DeclareMapperOp / DeclareMapperInfoOp
3213//===----------------------------------------------------------------------===//
3214
3215LogicalResult DeclareMapperInfoOp::verify() {
3216 return verifyMapClause(*this, getMapVars());
3217}
3218
3219LogicalResult DeclareMapperOp::verifyRegions() {
3220 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3221 getRegion().getBlocks().front().getTerminator()))
3222 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3223
3224 return success();
3225}
3226
3227//===----------------------------------------------------------------------===//
3228// DeclareReductionOp
3229//===----------------------------------------------------------------------===//
3230
3231LogicalResult DeclareReductionOp::verifyRegions() {
3232 if (!getAllocRegion().empty()) {
3233 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3234 if (yieldOp.getResults().size() != 1 ||
3235 yieldOp.getResults().getTypes()[0] != getType())
3236 return emitOpError() << "expects alloc region to yield a value "
3237 "of the reduction type";
3238 }
3239 }
3240
3241 if (getInitializerRegion().empty())
3242 return emitOpError() << "expects non-empty initializer region";
3243 Block &initializerEntryBlock = getInitializerRegion().front();
3244
3245 if (initializerEntryBlock.getNumArguments() == 1) {
3246 if (!getAllocRegion().empty())
3247 return emitOpError() << "expects two arguments to the initializer region "
3248 "when an allocation region is used";
3249 } else if (initializerEntryBlock.getNumArguments() == 2) {
3250 if (getAllocRegion().empty())
3251 return emitOpError() << "expects one argument to the initializer region "
3252 "when no allocation region is used";
3253 } else {
3254 return emitOpError()
3255 << "expects one or two arguments to the initializer region";
3256 }
3257
3258 for (mlir::Value arg : initializerEntryBlock.getArguments())
3259 if (arg.getType() != getType())
3260 return emitOpError() << "expects initializer region argument to match "
3261 "the reduction type";
3262
3263 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3264 if (yieldOp.getResults().size() != 1 ||
3265 yieldOp.getResults().getTypes()[0] != getType())
3266 return emitOpError() << "expects initializer region to yield a value "
3267 "of the reduction type";
3268 }
3269
3270 if (getReductionRegion().empty())
3271 return emitOpError() << "expects non-empty reduction region";
3272 Block &reductionEntryBlock = getReductionRegion().front();
3273 if (reductionEntryBlock.getNumArguments() != 2 ||
3274 reductionEntryBlock.getArgumentTypes()[0] !=
3275 reductionEntryBlock.getArgumentTypes()[1] ||
3276 reductionEntryBlock.getArgumentTypes()[0] != getType())
3277 return emitOpError() << "expects reduction region with two arguments of "
3278 "the reduction type";
3279 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3280 if (yieldOp.getResults().size() != 1 ||
3281 yieldOp.getResults().getTypes()[0] != getType())
3282 return emitOpError() << "expects reduction region to yield a value "
3283 "of the reduction type";
3284 }
3285
3286 if (!getAtomicReductionRegion().empty()) {
3287 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3288 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3289 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3290 atomicReductionEntryBlock.getArgumentTypes()[1])
3291 return emitOpError() << "expects atomic reduction region with two "
3292 "arguments of the same type";
3293 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3294 atomicReductionEntryBlock.getArgumentTypes()[0]);
3295 if (!ptrType ||
3296 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3297 return emitOpError() << "expects atomic reduction region arguments to "
3298 "be accumulators containing the reduction type";
3299 }
3300
3301 if (getCleanupRegion().empty())
3302 return success();
3303 Block &cleanupEntryBlock = getCleanupRegion().front();
3304 if (cleanupEntryBlock.getNumArguments() != 1 ||
3305 cleanupEntryBlock.getArgument(0).getType() != getType())
3306 return emitOpError() << "expects cleanup region with one argument "
3307 "of the reduction type";
3308
3309 return success();
3310}
3311
3312//===----------------------------------------------------------------------===//
3313// TaskOp
3314//===----------------------------------------------------------------------===//
3315
3316void TaskOp::build(OpBuilder &builder, OperationState &state,
3317 const TaskOperands &clauses) {
3318 MLIRContext *ctx = builder.getContext();
3319 TaskOp::build(
3320 builder, state, clauses.iterated, clauses.affinityVars,
3321 clauses.allocateVars, clauses.allocatorVars,
3322 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3323 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3324 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3325 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3326 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3327 clauses.priority, /*private_vars=*/clauses.privateVars,
3328 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3329 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3330}
3331
3332LogicalResult TaskOp::verify() {
3333 LogicalResult verifyDependVars =
3334 verifyDependVarList(*this, getDependKinds(), getDependVars(),
3335 getDependIteratedKinds(), getDependIterated());
3336 return failed(verifyDependVars)
3337 ? verifyDependVars
3338 : verifyReductionVarList(*this, getInReductionSyms(),
3339 getInReductionVars(),
3340 getInReductionByref());
3341}
3342
3343//===----------------------------------------------------------------------===//
3344// TaskgroupOp
3345//===----------------------------------------------------------------------===//
3346
3347void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3348 const TaskgroupOperands &clauses) {
3349 MLIRContext *ctx = builder.getContext();
3350 TaskgroupOp::build(builder, state, clauses.allocateVars,
3351 clauses.allocatorVars, clauses.taskReductionVars,
3352 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3353 makeArrayAttr(ctx, clauses.taskReductionSyms));
3354}
3355
3356LogicalResult TaskgroupOp::verify() {
3357 return verifyReductionVarList(*this, getTaskReductionSyms(),
3358 getTaskReductionVars(),
3359 getTaskReductionByref());
3360}
3361
3362//===----------------------------------------------------------------------===//
3363// TaskloopOp
3364//===----------------------------------------------------------------------===//
3365
3366void TaskloopOp::build(OpBuilder &builder, OperationState &state,
3367 const TaskloopOperands &clauses) {
3368 MLIRContext *ctx = builder.getContext();
3369 TaskloopOp::build(
3370 builder, state, clauses.allocateVars, clauses.allocatorVars,
3371 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3372 clauses.inReductionVars,
3373 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3374 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3375 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3376 /*private_vars=*/clauses.privateVars,
3377 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3378 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3379 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3380 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3381}
3382
3383LogicalResult TaskloopOp::verify() {
3384 if (getAllocateVars().size() != getAllocatorVars().size())
3385 return emitError(
3386 "expected equal sizes for allocate and allocator variables");
3387 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3388 getReductionVars(), getReductionByref())) ||
3389 failed(verifyReductionVarList(*this, getInReductionSyms(),
3390 getInReductionVars(),
3391 getInReductionByref())))
3392 return failure();
3393
3394 if (!getReductionVars().empty() && getNogroup())
3395 return emitError("if a reduction clause is present on the taskloop "
3396 "directive, the nogroup clause must not be specified");
3397 for (auto var : getReductionVars()) {
3398 if (llvm::is_contained(getInReductionVars(), var))
3399 return emitError("the same list item cannot appear in both a reduction "
3400 "and an in_reduction clause");
3401 }
3402
3403 if (getGrainsize() && getNumTasks()) {
3404 return emitError(
3405 "the grainsize clause and num_tasks clause are mutually exclusive and "
3406 "may not appear on the same taskloop directive");
3407 }
3408
3409 return success();
3410}
3411
3412LogicalResult TaskloopOp::verifyRegions() {
3413 if (LoopWrapperInterface nested = getNestedWrapper()) {
3414 if (!isComposite())
3415 return emitError()
3416 << "'omp.composite' attribute missing from composite wrapper";
3417
3418 // Check for the allowed leaf constructs that may appear in a composite
3419 // construct directly after TASKLOOP.
3420 if (!isa<SimdOp>(nested))
3421 return emitError() << "only supported nested wrapper is 'omp.simd'";
3422 } else if (isComposite()) {
3423 return emitError()
3424 << "'omp.composite' attribute present in non-composite wrapper";
3425 }
3426
3427 return success();
3428}
3429
3430//===----------------------------------------------------------------------===//
3431// LoopNestOp
3432//===----------------------------------------------------------------------===//
3433
/// Parses the custom form of a loop nest. The syntax is:
///   `(ivs) : type = (lbs) to (ubs) [inclusive] step (steps)
///    [collapse(N)] [tiles(t, ...)] region [attr-dict]`
/// All bounds and steps are resolved to the single induction-variable type.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // NOTE(review): this listing is missing several lines of this function:
  //  - the declarations of `ivs`, `lbs`, `ubs`, `steps` and `tiles`;
  //  - the `if (` that opens the `||` chain below (presumably starting with
  //    parsing the parenthesized induction-variable list).
  // Confirm against the upstream source.
  // Parse an opening `(` followed by induction variables followed by `)`
  Type loopVarType;
      parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Every induction variable shares the single type parsed after `:`.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  auto *ctx = parser.getBuilder().getContext();
  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive", UnitAttr::get(ctx));

  // Parse step values.
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse the optional `collapse(N)` clause. Only values greater than 1 are
  // recorded as `collapse_num_loops`.
  int64_t value = 0;
  if (!parser.parseOptionalKeyword("collapse") &&
      (parser.parseLParen() || parser.parseInteger(value) ||
       parser.parseRParen()))
    return failure();
  if (value > 1)
    result.addAttribute(
        "collapse_num_loops",
        IntegerAttr::get(parser.getBuilder().getI64Type(), value));

  // Parse the optional `tiles(t0, t1, ...)` clause. A non-empty list is
  // recorded as `tile_sizes`.
  auto parseTiles = [&]() -> ParseResult {
    int64_t tile;
    if (parser.parseInteger(tile))
      return failure();
    tiles.push_back(tile);
    return success();
  };

  if (!parser.parseOptionalKeyword("tiles") &&
      (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
       parser.parseRParen()))
    return failure();

  if (tiles.size() > 0)
    result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));

  // Parse the body.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
3505
3506void LoopNestOp::print(OpAsmPrinter &p) {
3507 Region &region = getRegion();
3508 auto args = region.getArguments();
3509 p << " (" << args << ") : " << args[0].getType() << " = ("
3510 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3511 if (getLoopInclusive())
3512 p << "inclusive ";
3513 p << "step (" << getLoopSteps() << ") ";
3514 if (int64_t numCollapse = getCollapseNumLoops())
3515 if (numCollapse > 1)
3516 p << "collapse(" << numCollapse << ") ";
3517
3518 if (const auto tiles = getTileSizes())
3519 p << "tiles(" << tiles.value() << ") ";
3520
3521 p.printRegion(region, /*printEntryBlockArgs=*/false);
3522}
3523
3524void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3525 const LoopNestOperands &clauses) {
3526 MLIRContext *ctx = builder.getContext();
3527 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3528 clauses.loopLowerBounds, clauses.loopUpperBounds,
3529 clauses.loopSteps, clauses.loopInclusive,
3530 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3531}
3532
3533LogicalResult LoopNestOp::verify() {
3534 if (getLoopLowerBounds().empty())
3535 return emitOpError() << "must represent at least one loop";
3536
3537 if (getLoopLowerBounds().size() != getIVs().size())
3538 return emitOpError() << "number of range arguments and IVs do not match";
3539
3540 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3541 if (lb.getType() != iv.getType())
3542 return emitOpError()
3543 << "range argument type does not match corresponding IV type";
3544 }
3545
3546 uint64_t numIVs = getIVs().size();
3547
3548 if (const auto &numCollapse = getCollapseNumLoops())
3549 if (numCollapse > numIVs)
3550 return emitOpError()
3551 << "collapse value is larger than the number of loops";
3552
3553 if (const auto &tiles = getTileSizes())
3554 if (tiles.value().size() > numIVs)
3555 return emitOpError() << "too few canonical loops for tile dimensions";
3556
3557 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3558 return emitOpError() << "expects parent op to be a loop wrapper";
3559
3560 return success();
3561}
3562
3563void LoopNestOp::gatherWrappers(
3565 Operation *parent = (*this)->getParentOp();
3566 while (auto wrapper =
3567 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3568 wrappers.push_back(wrapper);
3569 parent = parent->getParentOp();
3570 }
3571}
3572
3573//===----------------------------------------------------------------------===//
3574// OpenMP canonical loop handling
3575//===----------------------------------------------------------------------===//
3576
3577std::tuple<NewCliOp, OpOperand *, OpOperand *>
3578mlir::omp ::decodeCli(Value cli) {
3579
3580 // Defining a CLI for a generated loop is optional; if there is none then
3581 // there is no followup-tranformation
3582 if (!cli)
3583 return {{}, nullptr, nullptr};
3584
3585 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3586 "Unexpected type of cli");
3587
3588 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3589 OpOperand *gen = nullptr;
3590 OpOperand *cons = nullptr;
3591 for (OpOperand &use : cli.getUses()) {
3592 auto op = cast<LoopTransformationInterface>(use.getOwner());
3593
3594 unsigned opnum = use.getOperandNumber();
3595 if (op.isGeneratee(opnum)) {
3596 assert(!gen && "Each CLI may have at most one def");
3597 gen = &use;
3598 } else if (op.isApplyee(opnum)) {
3599 assert(!cons && "Each CLI may have at most one consumer");
3600 cons = &use;
3601 } else {
3602 llvm_unreachable("Unexpected operand for a CLI");
3603 }
3604 }
3605
3606 return {create, gen, cons};
3607}
3608
3609void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3610 ::mlir::OperationState &odsState) {
3611 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3612}
3613
3614void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3615 Value result = getResult();
3616 auto [newCli, gen, cons] = decodeCli(result);
3617
3618 // Structured binding `gen` cannot be captured in lambdas before C++20
3619 OpOperand *generator = gen;
3620
3621 // Derive the CLI variable name from its generator:
3622 // * "canonloop" for omp.canonical_loop
3623 // * custom name for loop transformation generatees
3624 // * "cli" as fallback if no generator
3625 // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3626 // at that level
3627 // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3628 // the index of that region
3629 std::string cliName{"cli"};
3630 if (gen) {
3631 cliName =
3633 .Case([&](CanonicalLoopOp op) {
3634 return generateLoopNestingName("canonloop", op);
3635 })
3636 .Case([&](UnrollHeuristicOp op) -> std::string {
3637 llvm_unreachable("heuristic unrolling does not generate a loop");
3638 })
3639 .Case([&](FuseOp op) -> std::string {
3640 unsigned opnum = generator->getOperandNumber();
3641 // The position of the first loop to be fused is the same position
3642 // as the resulting fused loop
3643 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3644 return "canonloop_fuse";
3645 else
3646 return "fused";
3647 })
3648 .Case([&](TileOp op) -> std::string {
3649 auto [generateesFirst, generateesCount] =
3650 op.getGenerateesODSOperandIndexAndLength();
3651 unsigned firstGrid = generateesFirst;
3652 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3653 unsigned end = generateesFirst + generateesCount;
3654 unsigned opnum = generator->getOperandNumber();
3655 // In the OpenMP apply and looprange clauses, indices are 1-based
3656 if (firstGrid <= opnum && opnum < firstIntratile) {
3657 unsigned gridnum = opnum - firstGrid + 1;
3658 return ("grid" + Twine(gridnum)).str();
3659 }
3660 if (firstIntratile <= opnum && opnum < end) {
3661 unsigned intratilenum = opnum - firstIntratile + 1;
3662 return ("intratile" + Twine(intratilenum)).str();
3663 }
3664 llvm_unreachable("Unexpected generatee argument");
3665 })
3666 .DefaultUnreachable("TODO: Custom name for this operation");
3667 }
3668
3669 setNameFn(result, cliName);
3670}
3671
3672LogicalResult NewCliOp::verify() {
3673 Value cli = getResult();
3674
3675 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3676 "Unexpected type of cli");
3677
3678 // Check that the CLI is used in at most generator and one consumer
3679 OpOperand *gen = nullptr;
3680 OpOperand *cons = nullptr;
3681 for (mlir::OpOperand &use : cli.getUses()) {
3682 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3683
3684 unsigned opnum = use.getOperandNumber();
3685 if (op.isGeneratee(opnum)) {
3686 if (gen) {
3687 InFlightDiagnostic error =
3688 emitOpError("CLI must have at most one generator");
3689 error.attachNote(gen->getOwner()->getLoc())
3690 .append("first generator here:");
3691 error.attachNote(use.getOwner()->getLoc())
3692 .append("second generator here:");
3693 return error;
3694 }
3695
3696 gen = &use;
3697 } else if (op.isApplyee(opnum)) {
3698 if (cons) {
3699 InFlightDiagnostic error =
3700 emitOpError("CLI must have at most one consumer");
3701 error.attachNote(cons->getOwner()->getLoc())
3702 .append("first consumer here:")
3703 .appendOp(*cons->getOwner(),
3704 OpPrintingFlags().printGenericOpForm());
3705 error.attachNote(use.getOwner()->getLoc())
3706 .append("second consumer here:")
3707 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3708 return error;
3709 }
3710
3711 cons = &use;
3712 } else {
3713 llvm_unreachable("Unexpected operand for a CLI");
3714 }
3715 }
3716
3717 // If the CLI is source of a transformation, it must have a generator
3718 if (cons && !gen) {
3719 InFlightDiagnostic error = emitOpError("CLI has no generator");
3720 error.attachNote(cons->getOwner()->getLoc())
3721 .append("see consumer here: ")
3722 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3723 return error;
3724 }
3725
3726 return success();
3727}
3728
3729void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3730 Value tripCount) {
3731 odsState.addOperands(tripCount);
3732 odsState.addOperands(Value());
3733 (void)odsState.addRegion();
3734}
3735
3736void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3737 Value tripCount, ::mlir::Value cli) {
3738 odsState.addOperands(tripCount);
3739 odsState.addOperands(cli);
3740 (void)odsState.addRegion();
3741}
3742
3743void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3744 setNameFn(&getRegion().front(), "body_entry");
3745}
3746
3747void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3748 OpAsmSetValueNameFn setNameFn) {
3749 std::string ivName = generateLoopNestingName("iv", *this);
3750 setNameFn(region.getArgument(0), ivName);
3751}
3752
3753void CanonicalLoopOp::print(OpAsmPrinter &p) {
3754 if (getCli())
3755 p << '(' << getCli() << ')';
3756 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3757 << " in range(" << getTripCount() << ") ";
3758
3759 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3760 /*printBlockTerminators=*/true);
3761
3762 p.printOptionalAttrDict((*this)->getAttrs());
3763}
3764
3765mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3767 CanonicalLoopInfoType cliType =
3768 CanonicalLoopInfoType::get(parser.getContext());
3769
3770 // Parse (optional) omp.cli identifier
3772 SmallVector<mlir::Value, 1> cliOperand;
3773 if (!parser.parseOptionalLParen()) {
3774 if (parser.parseOperand(cli) ||
3775 parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3776 return failure();
3777 }
3778
3779 // We derive the type of tripCount from inductionVariable. MLIR requires the
3780 // type of tripCount to be known when calling resolveOperand so we have parse
3781 // the type before processing the inductionVariable.
3782 OpAsmParser::Argument inductionVariable;
3784 if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3785 parser.parseKeyword("in") || parser.parseKeyword("range") ||
3786 parser.parseLParen() || parser.parseOperand(tripcount) ||
3787 parser.parseRParen() ||
3788 parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3789 return failure();
3790
3791 // Parse the loop body.
3792 Region *region = result.addRegion();
3793 if (parser.parseRegion(*region, {inductionVariable}))
3794 return failure();
3795
3796 // We parsed the cli operand forst, but because it is optional, it must be
3797 // last in the operand list.
3798 result.operands.append(cliOperand);
3799
3800 // Parse the optional attribute list.
3801 if (parser.parseOptionalAttrDict(result.attributes))
3802 return failure();
3803
3804 return mlir::success();
3805}
3806
3807LogicalResult CanonicalLoopOp::verify() {
3808 // The region's entry must accept the induction variable
3809 // It can also be empty if just created
3810 if (!getRegion().empty()) {
3811 Region &region = getRegion();
3812 if (region.getNumArguments() != 1)
3813 return emitOpError(
3814 "Canonical loop region must have exactly one argument");
3815
3816 if (getInductionVar().getType() != getTripCount().getType())
3817 return emitOpError(
3818 "Region argument must be the same type as the trip count");
3819 }
3820
3821 return success();
3822}
3823
3824Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3825
3826std::pair<unsigned, unsigned>
3827CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3828 // No applyees
3829 return {0, 0};
3830}
3831
3832std::pair<unsigned, unsigned>
3833CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3834 return getODSOperandIndexAndLength(odsIndex_cli);
3835}
3836
3837//===----------------------------------------------------------------------===//
3838// UnrollHeuristicOp
3839//===----------------------------------------------------------------------===//
3840
3841void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3842 ::mlir::OperationState &odsState,
3843 ::mlir::Value cli) {
3844 odsState.addOperands(cli);
3845}
3846
3847void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3848 p << '(' << getApplyee() << ')';
3849
3850 p.printOptionalAttrDict((*this)->getAttrs());
3851}
3852
3853mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3855 auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3856
3857 if (parser.parseLParen())
3858 return failure();
3859
3861 if (parser.parseOperand(applyee) ||
3862 parser.resolveOperand(applyee, cliType, result.operands))
3863 return failure();
3864
3865 if (parser.parseRParen())
3866 return failure();
3867
3868 // Optional output loop (full unrolling has none)
3869 if (!parser.parseOptionalArrow()) {
3870 if (parser.parseLParen() || parser.parseRParen())
3871 return failure();
3872 }
3873
3874 // Parse the optional attribute list.
3875 if (parser.parseOptionalAttrDict(result.attributes))
3876 return failure();
3877
3878 return mlir::success();
3879}
3880
3881std::pair<unsigned, unsigned>
3882UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3883 return getODSOperandIndexAndLength(odsIndex_applyee);
3884}
3885
3886std::pair<unsigned, unsigned>
3887UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3888 return {0, 0};
3889}
3890
3891//===----------------------------------------------------------------------===//
3892// TileOp
3893//===----------------------------------------------------------------------===//
3894
3895static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3896 OperandRange generatees,
3897 OperandRange applyees) {
3898 if (!generatees.empty())
3899 p << '(' << llvm::interleaved(generatees) << ')';
3900
3901 if (!applyees.empty())
3902 p << " <- (" << llvm::interleaved(applyees) << ')';
3903}
3904
3905static ParseResult parseLoopTransformClis(
3906 OpAsmParser &parser,
3909 if (parser.parseOptionalLess()) {
3910 // Syntax 1: generatees present
3911
3912 if (parser.parseOperandList(generateesOperands,
3914 return failure();
3915
3916 if (parser.parseLess())
3917 return failure();
3918 } else {
3919 // Syntax 2: generatees omitted
3920 }
3921
3922 // Parse `<-` (`<` has already been parsed)
3923 if (parser.parseMinus())
3924 return failure();
3925
3926 if (parser.parseOperandList(applyeesOperands,
3928 return failure();
3929
3930 return success();
3931}
3932
3933/// Check properties of the loop nest consisting of the transformation's
3934/// applyees:
3935/// 1. They are nested inside each other
3936/// 2. They are perfectly nested
3937/// (no code with side-effects in-between the loops)
3938/// 3. They are rectangular
3939/// (loop bounds are invariant in respect to the outer loops)
3940///
3941/// TODO: Generalize for LoopTransformationInterface.
3942static LogicalResult checkApplyeesNesting(TileOp op) {
3943 // Collect the loops from the nest
3944 bool isOnlyCanonLoops = true;
3946 for (Value applyee : op.getApplyees()) {
3947 auto [create, gen, cons] = decodeCli(applyee);
3948
3949 if (!gen)
3950 return op.emitOpError() << "applyee CLI has no generator";
3951
3952 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3953 canonLoops.push_back(loop);
3954 if (!loop)
3955 isOnlyCanonLoops = false;
3956 }
3957
3958 // FIXME: We currently can only verify non-rectangularity and perfect nest of
3959 // omp.canonical_loop.
3960 if (!isOnlyCanonLoops)
3961 return success();
3962
3963 DenseSet<Value> parentIVs;
3964 for (auto i : llvm::seq<int>(1, canonLoops.size())) {
3965 auto parentLoop = canonLoops[i - 1];
3966 auto loop = canonLoops[i];
3967
3968 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
3969 return op.emitOpError()
3970 << "tiled loop nest must be nested within each other";
3971
3972 parentIVs.insert(parentLoop.getInductionVar());
3973
3974 // Canonical loop must be perfectly nested, i.e. the body of the parent must
3975 // only contain the omp.canonical_loop of the nested loops, and
3976 // omp.terminator
3977 bool isPerfectlyNested = [&]() {
3978 auto &parentBody = parentLoop.getRegion();
3979 if (!parentBody.hasOneBlock())
3980 return false;
3981 auto &parentBlock = parentBody.getBlocks().front();
3982
3983 auto nestedLoopIt = parentBlock.begin();
3984 if (nestedLoopIt == parentBlock.end() ||
3985 (&*nestedLoopIt != loop.getOperation()))
3986 return false;
3987
3988 auto termIt = std::next(nestedLoopIt);
3989 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3990 return false;
3991
3992 if (std::next(termIt) != parentBlock.end())
3993 return false;
3994
3995 return true;
3996 }();
3997 if (!isPerfectlyNested)
3998 return op.emitOpError() << "tiled loop nest must be perfectly nested";
3999
4000 if (parentIVs.contains(loop.getTripCount()))
4001 return op.emitOpError() << "tiled loop nest must be rectangular";
4002 }
4003
4004 // TODO: The tile sizes must be computed before the loop, but checking this
4005 // requires dominance analysis. For instance:
4006 //
4007 // %canonloop = omp.new_cli
4008 // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
4009 // // write to %x
4010 // omp.terminator
4011 // }
4012 // %ts = llvm.load %x
4013 // omp.tile <- (%canonloop) sizes(%ts : i32)
4014
4015 return success();
4016}
4017
4018LogicalResult TileOp::verify() {
4019 if (getApplyees().empty())
4020 return emitOpError() << "must apply to at least one loop";
4021
4022 if (getSizes().size() != getApplyees().size())
4023 return emitOpError() << "there must be one tile size for each applyee";
4024
4025 if (!getGeneratees().empty() &&
4026 2 * getSizes().size() != getGeneratees().size())
4027 return emitOpError()
4028 << "expecting two times the number of generatees than applyees";
4029
4030 return checkApplyeesNesting(*this);
4031}
4032
4033std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4034 return getODSOperandIndexAndLength(odsIndex_applyees);
4035}
4036
4037std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4038 return getODSOperandIndexAndLength(odsIndex_generatees);
4039}
4040
4041//===----------------------------------------------------------------------===//
4042// FuseOp
4043//===----------------------------------------------------------------------===//
4044
4045static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
4046 OperandRange generatees,
4047 OperandRange applyees) {
4048 if (!generatees.empty())
4049 p << '(' << llvm::interleaved(generatees) << ')';
4050
4051 if (!applyees.empty())
4052 p << " <- (" << llvm::interleaved(applyees) << ')';
4053}
4054
4055LogicalResult FuseOp::verify() {
4056 if (getApplyees().size() < 2)
4057 return emitOpError() << "must apply to at least two loops";
4058
4059 if (getFirst().has_value() && getCount().has_value()) {
4060 int64_t first = getFirst().value();
4061 int64_t count = getCount().value();
4062 if ((unsigned)(first + count - 1) > getApplyees().size())
4063 return emitOpError() << "the numbers of applyees must be at least first "
4064 "minus one plus count attributes";
4065 if (!getGeneratees().empty() &&
4066 getGeneratees().size() != getApplyees().size() + 1 - count)
4067 return emitOpError() << "the number of generatees must be the number of "
4068 "aplyees plus one minus count";
4069
4070 } else {
4071 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4072 return emitOpError()
4073 << "in a complete fuse the number of generatees must be exactly 1";
4074 }
4075 for (auto &&applyee : getApplyees()) {
4076 auto [create, gen, cons] = decodeCli(applyee);
4077
4078 if (!gen)
4079 return emitOpError() << "applyee CLI has no generator";
4080 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4081 if (!loop)
4082 return emitOpError()
4083 << "currently only supports omp.canonical_loop as applyee";
4084 }
4085 return success();
4086}
4087std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4088 return getODSOperandIndexAndLength(odsIndex_applyees);
4089}
4090
4091std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4092 return getODSOperandIndexAndLength(odsIndex_generatees);
4093}
4094
4095//===----------------------------------------------------------------------===//
4096// Critical construct (2.17.1)
4097//===----------------------------------------------------------------------===//
4098
4099void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
4100 const CriticalDeclareOperands &clauses) {
4101 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4102}
4103
4104LogicalResult CriticalDeclareOp::verify() {
4105 return verifySynchronizationHint(*this, getHint());
4106}
4107
4108LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4109 if (getNameAttr()) {
4110 SymbolRefAttr symbolRef = getNameAttr();
4111 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
4112 *this, symbolRef);
4113 if (!decl) {
4114 return emitOpError() << "expected symbol reference " << symbolRef
4115 << " to point to a critical declaration";
4116 }
4117 }
4118
4119 return success();
4120}
4121
4122//===----------------------------------------------------------------------===//
4123// Ordered construct
4124//===----------------------------------------------------------------------===//
4125
4126static LogicalResult verifyOrderedParent(Operation &op) {
4127 bool hasRegion = op.getNumRegions() > 0;
4128 auto loopOp = op.getParentOfType<LoopNestOp>();
4129 if (!loopOp) {
4130 if (hasRegion)
4131 return success();
4132
4133 // TODO: Consider if this needs to be the case only for the standalone
4134 // variant of the ordered construct.
4135 return op.emitOpError() << "must be nested inside of a loop";
4136 }
4137
4138 Operation *wrapper = loopOp->getParentOp();
4139 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4140 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4141 if (!orderedAttr)
4142 return op.emitOpError() << "the enclosing worksharing-loop region must "
4143 "have an ordered clause";
4144
4145 if (hasRegion && orderedAttr.getInt() != 0)
4146 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4147 "have a parameter present";
4148
4149 if (!hasRegion && orderedAttr.getInt() == 0)
4150 return op.emitOpError() << "the enclosing loop's ordered clause must "
4151 "have a parameter present";
4152 } else if (!isa<SimdOp>(wrapper)) {
4153 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4154 "or worksharing simd loop";
4155 }
4156 return success();
4157}
4158
4159void OrderedOp::build(OpBuilder &builder, OperationState &state,
4160 const OrderedOperands &clauses) {
4161 OrderedOp::build(builder, state, clauses.doacrossDependType,
4162 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4163}
4164
4165LogicalResult OrderedOp::verify() {
4166 if (failed(verifyOrderedParent(**this)))
4167 return failure();
4168
4169 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4170 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4171 return emitOpError() << "number of variables in depend clause does not "
4172 << "match number of iteration variables in the "
4173 << "doacross loop";
4174
4175 return success();
4176}
4177
4178void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4179 const OrderedRegionOperands &clauses) {
4180 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4181}
4182
4183LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4184
4185//===----------------------------------------------------------------------===//
4186// TaskwaitOp
4187//===----------------------------------------------------------------------===//
4188
4189void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
4190 const TaskwaitOperands &clauses) {
4191 // TODO Store clauses in op: dependKinds, dependVars, nowait.
4192 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
4193 /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
4194 /*depend_iterated=*/{}, /*nowait=*/nullptr);
4195}
4196
4197//===----------------------------------------------------------------------===//
4198// Verifier for AtomicReadOp
4199//===----------------------------------------------------------------------===//
4200
4201LogicalResult AtomicReadOp::verify() {
4202 if (verifyCommon().failed())
4203 return mlir::failure();
4204
4205 if (auto mo = getMemoryOrder()) {
4206 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4207 *mo == ClauseMemoryOrderKind::Release) {
4208 return emitError(
4209 "memory-order must not be acq_rel or release for atomic reads");
4210 }
4211 }
4212 return verifySynchronizationHint(*this, getHint());
4213}
4214
4215//===----------------------------------------------------------------------===//
4216// Verifier for AtomicWriteOp
4217//===----------------------------------------------------------------------===//
4218
4219LogicalResult AtomicWriteOp::verify() {
4220 if (verifyCommon().failed())
4221 return mlir::failure();
4222
4223 if (auto mo = getMemoryOrder()) {
4224 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4225 *mo == ClauseMemoryOrderKind::Acquire) {
4226 return emitError(
4227 "memory-order must not be acq_rel or acquire for atomic writes");
4228 }
4229 }
4230 return verifySynchronizationHint(*this, getHint());
4231}
4232
4233//===----------------------------------------------------------------------===//
4234// Verifier for AtomicUpdateOp
4235//===----------------------------------------------------------------------===//
4236
4237LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4238 PatternRewriter &rewriter) {
4239 if (op.isNoOp()) {
4240 rewriter.eraseOp(op);
4241 return success();
4242 }
4243 if (Value writeVal = op.getWriteOpVal()) {
4244 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4245 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4246 return success();
4247 }
4248 return failure();
4249}
4250
4251LogicalResult AtomicUpdateOp::verify() {
4252 if (verifyCommon().failed())
4253 return mlir::failure();
4254
4255 if (auto mo = getMemoryOrder()) {
4256 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4257 *mo == ClauseMemoryOrderKind::Acquire) {
4258 return emitError(
4259 "memory-order must not be acq_rel or acquire for atomic updates");
4260 }
4261 }
4262
4263 return verifySynchronizationHint(*this, getHint());
4264}
4265
4266LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4267
4268//===----------------------------------------------------------------------===//
4269// Verifier for AtomicCaptureOp
4270//===----------------------------------------------------------------------===//
4271
4272AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4273 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4274 return op;
4275 return dyn_cast<AtomicReadOp>(getSecondOp());
4276}
4277
4278AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4279 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4280 return op;
4281 return dyn_cast<AtomicWriteOp>(getSecondOp());
4282}
4283
4284AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4285 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4286 return op;
4287 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4288}
4289
4290LogicalResult AtomicCaptureOp::verify() {
4291 return verifySynchronizationHint(*this, getHint());
4292}
4293
4294LogicalResult AtomicCaptureOp::verifyRegions() {
4295 if (verifyRegionsCommon().failed())
4296 return mlir::failure();
4297
4298 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4299 return emitOpError(
4300 "operations inside capture region must not have hint clause");
4301
4302 if (getFirstOp()->getAttr("memory_order") ||
4303 getSecondOp()->getAttr("memory_order"))
4304 return emitOpError(
4305 "operations inside capture region must not have memory_order clause");
4306 return success();
4307}
4308
4309//===----------------------------------------------------------------------===//
4310// CancelOp
4311//===----------------------------------------------------------------------===//
4312
4313void CancelOp::build(OpBuilder &builder, OperationState &state,
4314 const CancelOperands &clauses) {
4315 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4316}
4317
4319 Operation *parent = thisOp->getParentOp();
4320 while (parent) {
4321 if (parent->getDialect() == thisOp->getDialect())
4322 return parent;
4323 parent = parent->getParentOp();
4324 }
4325 return nullptr;
4326}
4327
4328LogicalResult CancelOp::verify() {
4329 ClauseCancellationConstructType cct = getCancelDirective();
4330 // The next OpenMP operation in the chain of parents
4331 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4332 if (!structuralParent)
4333 return emitOpError() << "Orphaned cancel construct";
4334
4335 if ((cct == ClauseCancellationConstructType::Parallel) &&
4336 !mlir::isa<ParallelOp>(structuralParent)) {
4337 return emitOpError() << "cancel parallel must appear "
4338 << "inside a parallel region";
4339 }
4340 if (cct == ClauseCancellationConstructType::Loop) {
4341 // structural parent will be omp.loop_nest, directly nested inside
4342 // omp.wsloop
4343 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4344
4345 if (!wsloopOp) {
4346 return emitOpError()
4347 << "cancel loop must appear inside a worksharing-loop region";
4348 }
4349 if (wsloopOp.getNowaitAttr()) {
4350 return emitError() << "A worksharing construct that is canceled "
4351 << "must not have a nowait clause";
4352 }
4353 if (wsloopOp.getOrderedAttr()) {
4354 return emitError() << "A worksharing construct that is canceled "
4355 << "must not have an ordered clause";
4356 }
4357
4358 } else if (cct == ClauseCancellationConstructType::Sections) {
4359 // structural parent will be an omp.section, directly nested inside
4360 // omp.sections
4361 auto sectionsOp =
4362 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4363 if (!sectionsOp) {
4364 return emitOpError() << "cancel sections must appear "
4365 << "inside a sections region";
4366 }
4367 if (sectionsOp.getNowait()) {
4368 return emitError() << "A sections construct that is canceled "
4369 << "must not have a nowait clause";
4370 }
4371 }
4372 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4373 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4374 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4375 return emitOpError() << "cancel taskgroup must appear "
4376 << "inside a task region";
4377 }
4378 return success();
4379}
4380
4381//===----------------------------------------------------------------------===//
4382// CancellationPointOp
4383//===----------------------------------------------------------------------===//
4384
4385void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4386 const CancellationPointOperands &clauses) {
4387 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4388}
4389
4390LogicalResult CancellationPointOp::verify() {
4391 ClauseCancellationConstructType cct = getCancelDirective();
4392 // The next OpenMP operation in the chain of parents
4393 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4394 if (!structuralParent)
4395 return emitOpError() << "Orphaned cancellation point";
4396
4397 if ((cct == ClauseCancellationConstructType::Parallel) &&
4398 !mlir::isa<ParallelOp>(structuralParent)) {
4399 return emitOpError() << "cancellation point parallel must appear "
4400 << "inside a parallel region";
4401 }
4402 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4403 // find the wsloop
4404 if ((cct == ClauseCancellationConstructType::Loop) &&
4405 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4406 return emitOpError() << "cancellation point loop must appear "
4407 << "inside a worksharing-loop region";
4408 }
4409 if ((cct == ClauseCancellationConstructType::Sections) &&
4410 !mlir::isa<omp::SectionOp>(structuralParent)) {
4411 return emitOpError() << "cancellation point sections must appear "
4412 << "inside a sections region";
4413 }
4414 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4415 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4416 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4417 return emitOpError() << "cancellation point taskgroup must appear "
4418 << "inside a task region";
4419 }
4420 return success();
4421}
4422
4423//===----------------------------------------------------------------------===//
4424// MapBoundsOp
4425//===----------------------------------------------------------------------===//
4426
4427LogicalResult MapBoundsOp::verify() {
4428 auto extent = getExtent();
4429 auto upperbound = getUpperBound();
4430 if (!extent && !upperbound)
4431 return emitError("expected extent or upperbound.");
4432 return success();
4433}
4434
4435void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4436 TypeRange /*result_types*/, StringAttr symName,
4437 TypeAttr type) {
4438 PrivateClauseOp::build(
4439 odsBuilder, odsState, symName, type,
4440 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4441 DataSharingClauseType::Private));
4442}
4443
4444LogicalResult PrivateClauseOp::verifyRegions() {
4445 Type argType = getArgType();
4446 auto verifyTerminator = [&](Operation *terminator,
4447 bool yieldsValue) -> LogicalResult {
4448 if (!terminator->getBlock()->getSuccessors().empty())
4449 return success();
4450
4451 if (!llvm::isa<YieldOp>(terminator))
4452 return mlir::emitError(terminator->getLoc())
4453 << "expected exit block terminator to be an `omp.yield` op.";
4454
4455 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4456 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4457
4458 if (!yieldsValue) {
4459 if (yieldedTypes.empty())
4460 return success();
4461
4462 return mlir::emitError(terminator->getLoc())
4463 << "Did not expect any values to be yielded.";
4464 }
4465
4466 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4467 return success();
4468
4469 auto error = mlir::emitError(yieldOp.getLoc())
4470 << "Invalid yielded value. Expected type: " << argType
4471 << ", got: ";
4472
4473 if (yieldedTypes.empty())
4474 error << "None";
4475 else
4476 error << yieldedTypes;
4477
4478 return error;
4479 };
4480
4481 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4482 StringRef regionName,
4483 bool yieldsValue) -> LogicalResult {
4484 assert(!region.empty());
4485
4486 if (region.getNumArguments() != expectedNumArgs)
4487 return mlir::emitError(region.getLoc())
4488 << "`" << regionName << "`: "
4489 << "expected " << expectedNumArgs
4490 << " region arguments, got: " << region.getNumArguments();
4491
4492 for (Block &block : region) {
4493 // MLIR will verify the absence of the terminator for us.
4494 if (!block.mightHaveTerminator())
4495 continue;
4496
4497 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4498 return failure();
4499 }
4500
4501 return success();
4502 };
4503
4504 // Ensure all of the region arguments have the same type
4505 for (Region *region : getRegions())
4506 for (Type ty : region->getArgumentTypes())
4507 if (ty != argType)
4508 return emitError() << "Region argument type mismatch: got " << ty
4509 << " expected " << argType << ".";
4510
4511 mlir::Region &initRegion = getInitRegion();
4512 if (!initRegion.empty() &&
4513 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4514 /*yieldsValue=*/true)))
4515 return failure();
4516
4517 DataSharingClauseType dsType = getDataSharingType();
4518
4519 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4520 return emitError("`private` clauses do not require a `copy` region.");
4521
4522 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4523 return emitError(
4524 "`firstprivate` clauses require at least a `copy` region.");
4525
4526 if (dsType == DataSharingClauseType::FirstPrivate &&
4527 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4528 /*yieldsValue=*/true)))
4529 return failure();
4530
4531 if (!getDeallocRegion().empty() &&
4532 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4533 /*yieldsValue=*/false)))
4534 return failure();
4535
4536 return success();
4537}
4538
4539//===----------------------------------------------------------------------===//
4540// Spec 5.2: Masked construct (10.5)
4541//===----------------------------------------------------------------------===//
4542
4543void MaskedOp::build(OpBuilder &builder, OperationState &state,
4544 const MaskedOperands &clauses) {
4545 MaskedOp::build(builder, state, clauses.filteredThreadId);
4546}
4547
4548//===----------------------------------------------------------------------===//
4549// Spec 5.2: Scan construct (5.6)
4550//===----------------------------------------------------------------------===//
4551
4552void ScanOp::build(OpBuilder &builder, OperationState &state,
4553 const ScanOperands &clauses) {
4554 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4555}
4556
4557LogicalResult ScanOp::verify() {
4558 if (hasExclusiveVars() == hasInclusiveVars())
4559 return emitError(
4560 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4561 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4562 if (parentWsLoopOp.getReductionModAttr() &&
4563 parentWsLoopOp.getReductionModAttr().getValue() ==
4564 ReductionModifier::inscan)
4565 return success();
4566 }
4567 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4568 if (parentSimdOp.getReductionModAttr() &&
4569 parentSimdOp.getReductionModAttr().getValue() ==
4570 ReductionModifier::inscan)
4571 return success();
4572 }
4573 return emitError("SCAN directive needs to be enclosed within a parent "
4574 "worksharing loop construct or SIMD construct with INSCAN "
4575 "reduction modifier");
4576}
4577
4578/// Verifies align clause in allocate directive
4579
4580LogicalResult AllocateDirOp::verify() {
4581 std::optional<uint64_t> align = this->getAlign();
4582
4583 if (align.has_value()) {
4584 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4585 return emitError() << "ALIGN value : " << align.value()
4586 << " must be power of 2";
4587 }
4588
4589 return success();
4590}
4591
4592//===----------------------------------------------------------------------===//
4593// TargetAllocMemOp
4594//===----------------------------------------------------------------------===//
4595
4596mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4597 return getInTypeAttr().getValue();
4598}
4599
4600/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4601/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4602/// attr-dict-without-keyword
4603static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4605 auto &builder = parser.getBuilder();
4606 bool hasOperands = false;
4607 std::int32_t typeparamsSize = 0;
4608
4609 // Parse device number as a new operand
4611 mlir::Type deviceType;
4612 if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4613 return mlir::failure();
4614 if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4615 return mlir::failure();
4616 if (parser.parseComma())
4617 return mlir::failure();
4618
4619 mlir::Type intype;
4620 if (parser.parseType(intype))
4621 return mlir::failure();
4622 result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4625 if (!parser.parseOptionalLParen()) {
4626 // parse the LEN params of the derived type. (<params> : <types>)
4628 parser.parseColonTypeList(typeVec) || parser.parseRParen())
4629 return mlir::failure();
4630 typeparamsSize = operands.size();
4631 hasOperands = true;
4632 }
4633 std::int32_t shapeSize = 0;
4634 if (!parser.parseOptionalComma()) {
4635 // parse size to scale by, vector of n dimensions of type index
4637 return mlir::failure();
4638 shapeSize = operands.size() - typeparamsSize;
4639 auto idxTy = builder.getIndexType();
4640 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4641 typeVec.push_back(idxTy);
4642 hasOperands = true;
4643 }
4644 if (hasOperands &&
4645 parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4646 result.operands))
4647 return mlir::failure();
4648
4649 mlir::Type restype = builder.getIntegerType(64);
4650 if (!restype) {
4651 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4652 return mlir::failure();
4653 }
4654 llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4655 result.addAttribute("operandSegmentSizes",
4656 builder.getDenseI32ArrayAttr(segmentSizes));
4657 if (parser.parseOptionalAttrDict(result.attributes) ||
4658 parser.addTypeToList(restype, result.types))
4659 return mlir::failure();
4660 return mlir::success();
4661}
4662
4663mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4665 return parseTargetAllocMemOp(parser, result);
4666}
4667
4668void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4669 p << " ";
4671 p << " : ";
4672 p << getDevice().getType();
4673 p << ", ";
4674 p << getInType();
4675 if (!getTypeparams().empty()) {
4676 p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4677 }
4678 for (auto sh : getShape()) {
4679 p << ", ";
4680 p.printOperand(sh);
4681 }
4682 p.printOptionalAttrDict((*this)->getAttrs(),
4683 {"in_type", "operandSegmentSizes"});
4684}
4685
4686llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4687 mlir::Type outType = getType();
4688 if (!mlir::dyn_cast<IntegerType>(outType))
4689 return emitOpError("must be a integer type");
4690 return mlir::success();
4691}
4692
4693//===----------------------------------------------------------------------===//
4694// WorkdistributeOp
4695//===----------------------------------------------------------------------===//
4696
4697LogicalResult WorkdistributeOp::verify() {
4698 // Check that region exists and is not empty
4699 Region &region = getRegion();
4700 if (region.empty())
4701 return emitOpError("region cannot be empty");
4702 // Verify single entry point.
4703 Block &entryBlock = region.front();
4704 if (entryBlock.empty())
4705 return emitOpError("region must contain a structured block");
4706 // Verify single exit point.
4707 bool hasTerminator = false;
4708 for (Block &block : region) {
4709 if (isa<TerminatorOp>(block.back())) {
4710 if (hasTerminator) {
4711 return emitOpError("region must have exactly one terminator");
4712 }
4713 hasTerminator = true;
4714 }
4715 }
4716 if (!hasTerminator) {
4717 return emitOpError("region must be terminated with omp.terminator");
4718 }
4719 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4720 // No implicit barrier at end
4721 if (isa<BarrierOp>(op)) {
4722 return emitOpError(
4723 "explicit barriers are not allowed in workdistribute region");
4724 }
4725 // Check for invalid nested constructs
4726 if (isa<ParallelOp>(op)) {
4727 return emitOpError(
4728 "nested parallel constructs not allowed in workdistribute");
4729 }
4730 if (isa<TeamsOp>(op)) {
4731 return emitOpError(
4732 "nested teams constructs not allowed in workdistribute");
4733 }
4734 return WalkResult::advance();
4735 });
4736 if (walkResult.wasInterrupted())
4737 return failure();
4738
4739 Operation *parentOp = (*this)->getParentOp();
4740 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4741 return emitOpError("workdistribute must be nested under teams");
4742 return success();
4743}
4744
4745//===----------------------------------------------------------------------===//
4746// Declare simd [7.7]
4747//===----------------------------------------------------------------------===//
4748
4749LogicalResult DeclareSimdOp::verify() {
4750 // Must be nested inside a function-like op
4751 auto func =
4752 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4753 if (!func)
4754 return emitOpError() << "must be nested inside a function";
4755
4756 if (getInbranch() && getNotinbranch())
4757 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
4758
4759 if (failed(verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars(),
4760 /*isDeclareSimd=*/true)))
4761 return failure();
4762
4763 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
4764}
4765
4766void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4767 const DeclareSimdOperands &clauses) {
4768 MLIRContext *ctx = odsBuilder.getContext();
4769 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4770 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
4771 clauses.linearVars, clauses.linearStepVars,
4772 clauses.linearVarTypes, clauses.linearModifiers,
4773 clauses.notinbranch, clauses.simdlen,
4774 clauses.uniformVars);
4775}
4776
4777//===----------------------------------------------------------------------===//
4778// Parser and printer for Uniform Clause
4779//===----------------------------------------------------------------------===//
4780
4781/// uniform ::= `uniform` `(` uniform-list `)`
4782/// uniform-list := uniform-val (`,` uniform-val)*
4783/// uniform-val := ssa-id `:` type
4784static ParseResult
4787 SmallVectorImpl<Type> &uniformTypes) {
4788 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
4789 if (parser.parseOperand(uniformVars.emplace_back()) ||
4790 parser.parseColonType(uniformTypes.emplace_back()))
4791 return mlir::failure();
4792 return mlir::success();
4793 });
4794}
4795
4796/// Print Uniform Clauses
4798 ValueRange uniformVars, TypeRange uniformTypes) {
4799 for (unsigned i = 0; i < uniformVars.size(); ++i) {
4800 if (i != 0)
4801 p << ", ";
4802 p << uniformVars[i] << " : " << uniformTypes[i];
4803 }
4804}
4805
4806//===----------------------------------------------------------------------===//
4807// Parser and printer for Affinity Clause
4808//===----------------------------------------------------------------------===//
4809
4810static ParseResult parseAffinityClause(
4811 OpAsmParser &parser,
4814 SmallVectorImpl<Type> &iteratedTypes,
4815 SmallVectorImpl<Type> &affinityVarTypes) {
4816 if (failed(parseSplitIteratedList(
4817 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4818 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
4819 return failure();
4820 return success();
4821}
4822
4824 ValueRange iterated, ValueRange affinityVars,
4825 TypeRange iteratedTypes,
4826 TypeRange affinityVarTypes) {
4827 auto nop = [&](Value, Type) {};
4828 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
4829 affinityVarTypes,
4830 /*plain prefix*/ nop,
4831 /*iterated prefix*/ nop);
4832}
4833
4834//===----------------------------------------------------------------------===//
4835// Parser, printer, and verifier for Iterator modifier
4836//===----------------------------------------------------------------------===//
4837
4838static ParseResult
4843 SmallVectorImpl<Type> &lbTypes,
4844 SmallVectorImpl<Type> &ubTypes,
4845 SmallVectorImpl<Type> &stepTypes) {
4846
4847 llvm::SMLoc ivLoc = parser.getCurrentLocation();
4849
4850 // Parse induction variables: %i : i32, %j : i32
4851 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4852 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4853 if (parser.parseArgument(arg))
4854 return failure();
4855
4856 // Optional type, default to Index if not provided
4857 if (succeeded(parser.parseOptionalColon())) {
4858 if (parser.parseType(arg.type))
4859 return failure();
4860 } else {
4861 arg.type = parser.getBuilder().getIndexType();
4862 }
4863 return success();
4864 }))
4865 return failure();
4866
4867 // ) = (
4868 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
4869 return failure();
4870
4871 // Parse Ranges: (%lb to %ub step %st, ...)
4872 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4873 OpAsmParser::UnresolvedOperand lb, ub, st;
4874 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
4875 parser.parseOperand(ub) || parser.parseKeyword("step") ||
4876 parser.parseOperand(st))
4877 return failure();
4878
4879 lbs.push_back(lb);
4880 ubs.push_back(ub);
4881 steps.push_back(st);
4882 return success();
4883 }))
4884 return failure();
4885
4886 if (parser.parseRParen())
4887 return failure();
4888
4889 if (ivArgs.size() != lbs.size())
4890 return parser.emitError(ivLoc)
4891 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
4892 << " ranges";
4893
4894 for (auto &arg : ivArgs) {
4895 lbTypes.push_back(arg.type);
4896 ubTypes.push_back(arg.type);
4897 stepTypes.push_back(arg.type);
4898 }
4899
4900 return parser.parseRegion(region, ivArgs);
4901}
4902
4904 ValueRange lbs, ValueRange ubs,
4906 TypeRange) {
4907 Block &entry = region.front();
4908
4909 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
4910 if (i != 0)
4911 p << ", ";
4912 p.printRegionArgument(entry.getArgument(i));
4913 }
4914 p << ") = (";
4915
4916 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
4917 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
4918 if (i)
4919 p << ", ";
4920 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
4921 }
4922 p << ") ";
4923
4924 p.printRegion(region, /*printEntryBlockArgs=*/false,
4925 /*printBlockTerminators=*/true);
4926}
4927
4928LogicalResult IteratorOp::verify() {
4929 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
4930 if (!iteratedTy)
4931 return emitOpError() << "result must be omp.iterated<entry_ty>";
4932
4933 for (auto [lb, ub, step] : llvm::zip_equal(
4934 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
4935 if (matchPattern(step, m_Zero()))
4936 return emitOpError() << "loop step must not be zero";
4937
4938 IntegerAttr lbAttr;
4939 IntegerAttr ubAttr;
4940 IntegerAttr stepAttr;
4941 if (!matchPattern(lb, m_Constant(&lbAttr)) ||
4942 !matchPattern(ub, m_Constant(&ubAttr)) ||
4943 !matchPattern(step, m_Constant(&stepAttr)))
4944 continue;
4945
4946 const APInt &lbVal = lbAttr.getValue();
4947 const APInt &ubVal = ubAttr.getValue();
4948 const APInt &stepVal = stepAttr.getValue();
4949 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
4950 return emitOpError() << "positive loop step requires lower bound to be "
4951 "less than or equal to upper bound";
4952 if (stepVal.isNegative() && lbVal.slt(ubVal))
4953 return emitOpError() << "negative loop step requires lower bound to be "
4954 "greater than or equal to upper bound";
4955 }
4956
4957 Block &b = getRegion().front();
4958 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
4959
4960 if (!yield)
4961 return emitOpError() << "region must be terminated by omp.yield";
4962
4963 if (yield.getNumOperands() != 1)
4964 return emitOpError()
4965 << "omp.yield in omp.iterator region must yield exactly one value";
4966
4967 mlir::Type yieldedTy = yield.getOperand(0).getType();
4968 mlir::Type elemTy = iteratedTy.getElementType();
4969
4970 if (yieldedTy != elemTy)
4971 return emitOpError() << "omp.iterated element type (" << elemTy
4972 << ") does not match omp.yield operand type ("
4973 << yieldedTy << ")";
4974
4975 return success();
4976}
4977
4978#define GET_ATTRDEF_CLASSES
4979#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4980
4981#define GET_OP_CLASSES
4982#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4983
4984#define GET_TYPEDEF_CLASSES
4985#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1497
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.