MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Matchers.h"
24#include "mlir/IR/SymbolTable.h"
27
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cstddef>
39#include <iterator>
40#include <optional>
41#include <variant>
42
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
47
48using namespace mlir;
49using namespace mlir::omp;
50
53 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
54}
55
58 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
59}
60
63 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
64}
65
66namespace {
67struct MemRefPointerLikeModel
68 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 MemRefType> {
70 Type getElementType(Type pointer) const {
71 return llvm::cast<MemRefType>(pointer).getElementType();
72 }
73};
74
75struct LLVMPointerPointerLikeModel
76 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
78 Type getElementType(Type pointer) const { return Type(); }
79};
80} // namespace
81
/// Generate a name of a canonical loop nest of the format
/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
/// argument index (i.e. the index of the region within its parent operation)
/// of an operation that has multiple regions, if the operation has multiple
/// regions.
/// `_s<idx>` identifies the position of an operation within a region, where
/// only operations that may potentially contain loops ("container operations",
/// i.e. operations that have at least one region) are counted. Again, it is
/// omitted if there is only one such operation in a region. If there are
/// canonical loops nested inside each other, the name may also use the format
/// `_d<num>`, where <num> is the nesting depth of the loop.
///
/// The generated name is a best-effort attempt to make canonical loop names
/// unique within an SSA namespace. This also means that regions with the
/// IsolatedFromAbove property do not consider any parents or siblings.
///
/// \param prefix Leading part of the generated name.
/// \param op     The canonical loop whose chain of ancestors is inspected.
/// \returns `prefix` followed by the `_r`/`_d`/`_s` suffixes (outermost
///          first) of every ancestor component that is not skipped.
static std::string generateLoopNestingName(StringRef prefix,
                                           CanonicalLoopOp op) {
  /// One step of the ancestor chain from `op` outwards.
  struct Component {
    /// If true, this component describes a region operand of an operation (the
    /// operand's owner). If false, this component describes an operation
    /// located in a parent region.
    bool isRegionArgOfOp;
    /// Set once the component is known not to contribute to the name.
    bool skip = false;
    /// True if no disambiguating index is needed: the owner has a single
    /// region, or the operation is the only container op in its region.
    bool isUnique = false;

    /// Region index (region-argument components) or position among container
    /// operations (operation components).
    size_t idx;
    Operation *op;
    Region *parentRegion;
    /// Number of loops surrounding this one in the current loop nest; only
    /// assigned and read for operation components.
    size_t loopDepth;

    Operation *&getOwnerOp() {
      assert(isRegionArgOfOp && "Must describe a region operand");
      return op;
    }
    size_t &getArgIdx() {
      assert(isRegionArgOfOp && "Must describe a region operand");
      return idx;
    }

    Operation *&getContainerOp() {
      assert(!isRegionArgOfOp && "Must describe a operation of a region");
      return op;
    }
    size_t &getOpPos() {
      assert(!isRegionArgOfOp && "Must describe a operation of a region");
      return idx;
    }
    bool isLoopOp() const {
      assert(!isRegionArgOfOp && "Must describe a operation of a region");
      return isa<CanonicalLoopOp>(op);
    }
    Region *&getParentRegion() {
      assert(!isRegionArgOfOp && "Must describe a operation of a region");
      return parentRegion;
    }
    size_t &getLoopDepth() {
      assert(!isRegionArgOfOp && "Must describe a operation of a region");
      return loopDepth;
    }

    /// Sticky skip flag: once a component is skipped it stays skipped.
    void skipIf(bool v = true) { skip = skip || v; }
  };

  // List of ancestors, from inner to outer.
  // Alternates between (starting with `op` itself as an operation)
  // * operation within a region
  // * region argument of an operation
  SmallVector<Component> components;

  // Gather a list of parent regions and operations, and the position within
  // their parent
  Operation *o = op.getOperation();
  while (o) {
    // Operation within a region
    Region *r = o->getParentRegion();
    if (!r)
      break;

    // Walk the region's blocks in reverse post-order and count container
    // operations (ops with at least one region). Record the position of `o`
    // among them and whether it is the only one.
    // NOTE(review): blocks unreachable from the entry block are not visited;
    // if `o` lives in one, `sequentialIdx` keeps its initial value (-1 wraps
    // to SIZE_MAX) -- confirm this cannot happen for canonical loops.
    llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
    size_t idx = 0;
    bool found = false;
    size_t sequentialIdx = -1;
    bool isOnlyContainerOp = true;
    for (Block *b : traversal) {
      // Note: this `op` shadows the function parameter of the same name.
      for (Operation &op : *b) {
        if (&op == o && !found) {
          sequentialIdx = idx;
          found = true;
        }
        if (op.getNumRegions()) {
          idx += 1;
          if (idx > 1)
            isOnlyContainerOp = false;
        }
        // Both results are final at this point. This exits only the per-block
        // loop; later blocks are still visited, which only advances `idx`.
        if (found && !isOnlyContainerOp)
          break;
      }
    }

    Component &containerOpInRegion = components.emplace_back();
    containerOpInRegion.isRegionArgOfOp = false;
    containerOpInRegion.isUnique = isOnlyContainerOp;
    containerOpInRegion.getContainerOp() = o;
    containerOpInRegion.getOpPos() = sequentialIdx;
    containerOpInRegion.getParentRegion() = r;

    Operation *parent = r->getParentOp();

    // Region argument of an operation. Starts out unique with index 0 and is
    // refined below when the parent has several regions.
    Component &regionArgOfOperation = components.emplace_back();
    regionArgOfOperation.isRegionArgOfOp = true;
    regionArgOfOperation.isUnique = true;
    regionArgOfOperation.getArgIdx() = 0;
    regionArgOfOperation.getOwnerOp() = parent;

    // The IsolatedFromAbove trait of the parent operation implies that each
    // individual region argument has its own separate namespace, so no
    // ambiguity.
    if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
      break;

    // Component only needed if operation has multiple region operands. Region
    // arguments may be optional, but we currently do not consider this.
    if (parent->getRegions().size() > 1) {
      auto getRegionIndex = [](Operation *o, Region *r) {
        for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
          if (&region == r)
            return idx;
        }
        llvm_unreachable("Region not child of its parent operation");
      };
      regionArgOfOperation.isUnique = false;
      regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
    }

    // next parent
    o = parent;
  }

  // Determine whether a region-argument component is not needed
  for (Component &c : components)
    c.skipIf(c.isRegionArgOfOp && c.isUnique);

  // Find runs of nested loops and determine each loop's depth in the loop nest
  // (this pass walks from the outermost component inwards).
  size_t numSurroundingLoops = 0;
  for (Component &c : llvm::reverse(components)) {
    if (c.skip)
      continue;

    // non-skipped multi-argument operands interrupt the loop nest
    if (c.isRegionArgOfOp) {
      numSurroundingLoops = 0;
      continue;
    }

    // Multiple loops in a region means each of them is the outermost loop of a
    // new loop nest
    if (!c.isUnique)
      numSurroundingLoops = 0;

    c.getLoopDepth() = numSurroundingLoops;

    // Next loop is surrounded by one more loop
    if (isa<CanonicalLoopOp>(c.getContainerOp()))
      numSurroundingLoops += 1;
  }

  // In loop nests, skip all but the innermost loop that contains the depth
  // number (this pass walks from the innermost component outwards).
  bool isLoopNest = false;
  for (Component &c : components) {
    if (c.skip || c.isRegionArgOfOp)
      continue;

    if (!isLoopNest && c.getLoopDepth() >= 1) {
      // Innermost loop of a loop nest of at least two loops
      isLoopNest = true;
    } else if (isLoopNest) {
      // Non-innermost loop of a loop nest
      c.skipIf(c.isUnique);

      // If there is no surrounding loop left, this must have been the outermost
      // loop; leave loop-nest mode for the next iteration
      if (c.getLoopDepth() == 0)
        isLoopNest = false;
    }
  }

  // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
  // we mark them as skipped only after computing loop nests)
  for (Component &c : components)
    c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
             !isa<CanonicalLoopOp>(c.getContainerOp()));

  // Components can be skipped if they are already disambiguated by their parent
  // (or do not have a parent). Walks from the outermost component inwards.
  bool newRegion = true;
  for (Component &c : llvm::reverse(components)) {
    c.skipIf(newRegion && c.isUnique);

    // non-skipped components disambiguate unique children
    if (!c.skip)
      newRegion = true;

    // ...except canonical loops that need a suffix for each nest
    // NOTE(review): getContainerOp() is never null here (it is set from the
    // non-null `o` above), so this fires for every operation component, not
    // only for canonical loops as the comment says -- confirm intent.
    if (!c.isRegionArgOfOp && c.getContainerOp())
      newRegion = false;
  }

  // Compile the nesting name string, outermost component first
  SmallString<64> Name{prefix};
  llvm::raw_svector_ostream NameOS(Name);
  for (auto &c : llvm::reverse(components)) {
    if (c.skip)
      continue;

    if (c.isRegionArgOfOp)
      NameOS << "_r" << c.getArgIdx();
    else if (c.getLoopDepth() >= 1)
      NameOS << "_d" << c.getLoopDepth();
    else
      NameOS << "_s" << c.getOpPos();
  }

  return NameOS.str().str();
}
307
308void OpenMPDialect::initialize() {
309 addOperations<
310#define GET_OP_LIST
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312 >();
313 addAttributes<
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316 >();
317 addTypes<
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 >();
321
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
326 *getContext());
327
328 // Attach default offload module interface to module op to access
329 // offload functionality through
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
331 *getContext());
332
333 // Attach default declare target interfaces to operations which can be marked
334 // as declare target (Global Operations and Functions/Subroutines in dialects
335 // that Fortran (or other languages that lower to MLIR) translates too
336 mlir::LLVM::GlobalOp::attachInterface<
338 *getContext());
339 mlir::LLVM::LLVMFuncOp::attachInterface<
341 *getContext());
342 mlir::func::FuncOp::attachInterface<
344}
345
346//===----------------------------------------------------------------------===//
347// Parser and printer for Allocate Clause
348//===----------------------------------------------------------------------===//
349
350/// Parse an allocate clause with allocators and a list of operands with types.
351///
352/// allocate-operand-list :: = allocate-operand |
353/// allocator-operand `,` allocate-operand-list
354/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
355/// ssa-id-and-type ::= ssa-id `:` type
356static ParseResult parseAllocateAndAllocator(
357 OpAsmParser &parser,
359 SmallVectorImpl<Type> &allocateTypes,
361 SmallVectorImpl<Type> &allocatorTypes) {
362
363 return parser.parseCommaSeparatedList([&]() {
365 Type type;
366 if (parser.parseOperand(operand) || parser.parseColonType(type))
367 return failure();
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
370 if (parser.parseArrow())
371 return failure();
372 if (parser.parseOperand(operand) || parser.parseColonType(type))
373 return failure();
374
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
377 return success();
378 });
379}
380
381/// Print allocate clause
383 OperandRange allocateVars,
384 TypeRange allocateTypes,
385 OperandRange allocatorVars,
386 TypeRange allocatorTypes) {
387 for (unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
389 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
390 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
391 }
392}
393
394//===----------------------------------------------------------------------===//
395// Parser and printer for a clause attribute (StringEnumAttr)
396//===----------------------------------------------------------------------===//
397
398template <typename ClauseAttr>
399static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
400 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
401 StringRef enumStr;
402 SMLoc loc = parser.getCurrentLocation();
403 if (parser.parseKeyword(&enumStr))
404 return failure();
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.getContext(), *enumValue);
407 return success();
408 }
409 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
410}
411
412template <typename ClauseAttr>
413static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
414 p << stringifyEnum(attr.getValue());
415}
416
417//===----------------------------------------------------------------------===//
418// Parser and printer for Linear Clause
419//===----------------------------------------------------------------------===//
420
421/// linear ::= `linear` `(` linear-list `)`
422/// linear-list := linear-val | linear-val linear-list
423/// linear-val := ssa-id-and-type `=` ssa-id-and-type
424/// | `val` `(` ssa-id-and-type `=` ssa-id-and-type `)`
425/// | `ref` `(` ssa-id-and-type `=` ssa-id-and-type `)`
426/// | `uval` `(` ssa-id-and-type `=` ssa-id-and-type `)`
427static ParseResult parseLinearClause(
428 OpAsmParser &parser,
430 SmallVectorImpl<Type> &linearTypes,
432 SmallVectorImpl<Type> &linearStepTypes, ArrayAttr &linearModifiers) {
433 SmallVector<Attribute> modifiers;
434 auto result = parser.parseCommaSeparatedList([&]() {
436 Type type, stepType;
438
439 std::optional<omp::LinearModifier> linearModifier;
440 if (succeeded(parser.parseOptionalKeyword("val"))) {
441 linearModifier = omp::LinearModifier::val;
442 } else if (succeeded(parser.parseOptionalKeyword("ref"))) {
443 linearModifier = omp::LinearModifier::ref;
444 } else if (succeeded(parser.parseOptionalKeyword("uval"))) {
445 linearModifier = omp::LinearModifier::uval;
446 }
447
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.parseLParen())
450 return failure();
451
452 if (parser.parseOperand(var) || parser.parseColonType(type) ||
453 parser.parseEqual() || parser.parseOperand(stepVar) ||
454 parser.parseColonType(stepType))
455 return failure();
456
457 if (hasLinearModifierParens && parser.parseRParen())
458 return failure();
459
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
465 modifiers.push_back(
466 omp::LinearModifierAttr::get(parser.getContext(), *linearModifier));
467 } else {
468 modifiers.push_back(UnitAttr::get(parser.getContext()));
469 }
470 return success();
471 });
472 if (failed(result))
473 return failure();
474 linearModifiers = ArrayAttr::get(parser.getContext(), modifiers);
475 return success();
476}
477
478/// Print Linear Clause
480 ValueRange linearVars, TypeRange linearTypes,
481 ValueRange linearStepVars, TypeRange stepVarTypes,
482 ArrayAttr linearModifiers) {
483 size_t linearVarsSize = linearVars.size();
484 for (unsigned i = 0; i < linearVarsSize; ++i) {
485 if (i != 0)
486 p << ", ";
487 // Print modifier keyword wrapper if present.
488 Attribute modAttr = linearModifiers ? linearModifiers[i] : nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) : nullptr;
490 if (mod) {
491 p << omp::stringifyLinearModifier(mod.getValue()) << "(";
492 }
493 p << linearVars[i] << " : " << linearTypes[i];
494 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
495 if (mod)
496 p << ")";
497 }
498}
499
500//===----------------------------------------------------------------------===//
501// Verifier for Linear modifier
502//===----------------------------------------------------------------------===//
503
504/// OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or
505/// uval only on a declare simd directive."
506/// Also verifies that modifier count matches variable count.
507static LogicalResult
508verifyLinearModifiers(Operation *op, std::optional<ArrayAttr> linearModifiers,
509 OperandRange linearVars, bool isDeclareSimd = false) {
510 if (!linearModifiers)
511 return success();
512 if (linearModifiers->size() != linearVars.size())
513 return op->emitOpError()
514 << "expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (Attribute attr : *linearModifiers) {
517 if (!attr)
518 continue;
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
520 if (!modAttr)
521 continue;
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
524 return op->emitOpError()
525 << "linear modifier '" << omp::stringifyLinearModifier(mod)
526 << "' may only be specified on a declare simd directive";
527 }
528 }
529 return success();
530}
531
532//===----------------------------------------------------------------------===//
533// Verifier for Nontemporal Clause
534//===----------------------------------------------------------------------===//
535
536static LogicalResult verifyNontemporalClause(Operation *op,
537 OperandRange nontemporalVars) {
538
539 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
540 DenseSet<Value> nontemporalItems;
541 for (const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->emitOpError() << "nontemporal variable used more than once";
544
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Parser, verifier and printer for Aligned Clause
550//===----------------------------------------------------------------------===//
551static LogicalResult verifyAlignedClause(Operation *op,
552 std::optional<ArrayAttr> alignments,
553 OperandRange alignedVars) {
554 // Check if number of alignment values equals to number of aligned variables
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
557 return op->emitOpError()
558 << "expected as many alignment values as aligned variables";
559 } else {
560 if (alignments)
561 return op->emitOpError() << "unexpected alignment values attribute";
562 return success();
563 }
564
565 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
566 DenseSet<Value> alignedItems;
567 for (auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->emitOpError() << "aligned variable used more than once";
570
571 if (!alignments)
572 return success();
573
574 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
575 for (unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->emitOpError() << "alignment should be greater than 0";
579 } else {
580 return op->emitOpError() << "expected integer alignment";
581 }
582 }
583
584 return success();
585}
586
587/// aligned ::= `aligned` `(` aligned-list `)`
588/// aligned-list := aligned-val | aligned-val aligned-list
589/// aligned-val := ssa-id-and-type `->` alignment
590static ParseResult
593 SmallVectorImpl<Type> &alignedTypes,
594 ArrayAttr &alignmentsAttr) {
595 SmallVector<Attribute> alignmentVec;
596 if (failed(parser.parseCommaSeparatedList([&]() {
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
601 return failure();
602 }
603 return success();
604 })))
605 return failure();
606 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
608 return success();
609}
610
611/// Print Aligned Clause
613 ValueRange alignedVars, TypeRange alignedTypes,
614 std::optional<ArrayAttr> alignments) {
615 for (unsigned i = 0; i < alignedVars.size(); ++i) {
616 if (i != 0)
617 p << ", ";
618 p << alignedVars[i] << " : " << alignedVars[i].getType();
619 p << " -> " << (*alignments)[i];
620 }
621}
622
623//===----------------------------------------------------------------------===//
624// Parser, printer and verifier for Schedule Clause
625//===----------------------------------------------------------------------===//
626
627static ParseResult
629 SmallVectorImpl<SmallString<12>> &modifiers) {
630 if (modifiers.size() > 2)
631 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
632 for (const auto &mod : modifiers) {
633 // Translate the string. If it has no value, then it was not a valid
634 // modifier!
635 auto symbol = symbolizeScheduleModifier(mod);
636 if (!symbol)
637 return parser.emitError(parser.getNameLoc())
638 << " unknown modifier type: " << mod;
639 }
640
641 // If we have one modifier that is "simd", then stick a "none" modiifer in
642 // index 0.
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
647 }
648 } else if (modifiers.size() == 2) {
649 // If there are two modifier:
650 // First modifier should not be simd, second one should be simd
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
653 return parser.emitError(parser.getNameLoc())
654 << " incorrect modifier order";
655 }
656 return success();
657}
658
659/// schedule ::= `schedule` `(` sched-list `)`
660/// sched-list ::= sched-val | sched-val sched-list |
661/// sched-val `,` sched-modifier
662/// sched-val ::= sched-with-chunk | sched-wo-chunk
663/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
664/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
665/// sched-wo-chunk ::= `auto` | `runtime`
666/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
667/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
668static ParseResult
669parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
672 Type &chunkType) {
673 StringRef keyword;
674 if (parser.parseKeyword(&keyword))
675 return failure();
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
678 if (!schedule)
679 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
680
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
682 switch (*schedule) {
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
686 if (succeeded(parser.parseOptionalEqual())) {
687 chunkSize = OpAsmParser::UnresolvedOperand{};
688 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
689 return failure();
690 } else {
691 chunkSize = std::nullopt;
692 }
693 break;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
698 }
699
700 // If there is a comma, we have one or more modifiers..
702 while (succeeded(parser.parseOptionalComma())) {
703 StringRef mod;
704 if (parser.parseKeyword(&mod))
705 return failure();
706 modifiers.push_back(mod);
707 }
708
709 if (verifyScheduleModifiers(parser, modifiers))
710 return failure();
711
712 if (!modifiers.empty()) {
713 SMLoc loc = parser.getCurrentLocation();
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
717 } else {
718 return parser.emitError(loc, "invalid schedule modifier");
719 }
720 // Only SIMD attribute is allowed here!
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
723 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
724 }
725 }
726
727 return success();
728}
729
730/// Print schedule clause
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd, Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
737 if (scheduleChunk)
738 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
739 if (scheduleMod)
740 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
741 if (scheduleSimd)
742 p << ", simd";
743}
744
745//===----------------------------------------------------------------------===//
746// Parser and printer for Order Clause
747//===----------------------------------------------------------------------===//
748
749// order ::= `order` `(` [order-modifier ':'] concurrent `)`
750// order-modifier ::= reproducible | unconstrained
751static ParseResult parseOrderClause(OpAsmParser &parser,
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
754 StringRef enumStr;
755 SMLoc loc = parser.getCurrentLocation();
756 if (parser.parseKeyword(&enumStr))
757 return failure();
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
761 if (parser.parseOptionalColon())
762 return failure();
763 loc = parser.getCurrentLocation();
764 if (parser.parseKeyword(&enumStr))
765 return failure();
766 }
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
770 return success();
771 }
772 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
773}
774
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
778 if (orderMod)
779 p << stringifyOrderModifier(orderMod.getValue()) << ":";
780 if (order)
781 p << stringifyClauseOrderKind(order.getValue());
782}
783
784template <typename ClauseTypeAttr, typename ClauseType>
785static ParseResult
786parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
788 Type &operandType,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
791 StringRef enumStr;
792 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
795 if (parser.parseComma())
796 return failure();
797 } else {
798 return parser.emitError(parser.getCurrentLocation())
799 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
800 ;
801 }
802 }
803
805 if (succeeded(parser.parseOperand(var))) {
806 operand = var;
807 } else {
808 return parser.emitError(parser.getCurrentLocation())
809 << "expected " << clauseName << " operand";
810 }
811
812 if (operand.has_value()) {
813 if (parser.parseColonType(operandType))
814 return failure();
815 }
816
817 return success();
818}
819
820template <typename ClauseTypeAttr, typename ClauseType>
821static void
823 ClauseTypeAttr prescriptiveness, Value operand,
824 mlir::Type operandType,
825 StringRef (*stringifyClauseType)(ClauseType)) {
826
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
829
830 if (operand)
831 p << operand << ": " << operandType;
832}
833
834//===----------------------------------------------------------------------===//
835// Parser and printer for grainsize Clause
836//===----------------------------------------------------------------------===//
837
838// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
839static ParseResult
840parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType, "grainsize");
846}
847
849 ClauseGrainsizeTypeAttr grainsizeMod,
850 Value grainsize, mlir::Type grainsizeType) {
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
854}
855
856//===----------------------------------------------------------------------===//
857// Parser and printer for num_tasks Clause
858//===----------------------------------------------------------------------===//
859
860// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
861static ParseResult
862parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
867 "num_tasks");
868}
869
871 ClauseNumTasksTypeAttr numTasksMod,
872 Value numTasks, mlir::Type numTasksType) {
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
875}
876
877//===----------------------------------------------------------------------===//
878// Parser and printer for Heap Alloc Clause
879//===----------------------------------------------------------------------===//
880
881/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
882static ParseResult parseHeapAllocClause(
883 OpAsmParser &parser, TypeAttr &inTypeAttr,
885 SmallVectorImpl<Type> &typeparamsTypes,
887 SmallVectorImpl<Type> &shapeTypes) {
888 mlir::Type inType;
889 if (parser.parseType(inType))
890 return mlir::failure();
891 inTypeAttr = TypeAttr::get(inType);
892
893 if (!parser.parseOptionalLParen()) {
894 // parse the LEN params of the derived type. (<params> : <types>)
895 if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
896 parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
897 return failure();
898 }
899
900 if (!parser.parseOptionalComma()) {
901 // parse size to scale by, vector of n dimensions of type index
903 return failure();
904
905 // TODO: This overrides the actual types of the operands, which might cause
906 // issues when they don't match. At the moment this is done in place of
907 // making the corresponding operand type `Variadic<Index>` because index
908 // types are lowered to I64 prior to LLVM IR translation.
909 shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
910 }
911
912 return success();
913}
914
916 TypeAttr inType, ValueRange typeparams,
917 TypeRange typeparamsTypes, ValueRange shape,
918 TypeRange shapeTypes) {
919 p << inType;
920 if (!typeparams.empty()) {
921 p << '(' << typeparams << " : " << typeparamsTypes << ')';
922 }
923 for (auto sh : shape) {
924 p << ", ";
925 p.printOperand(sh);
926 }
927}
928
929//===----------------------------------------------------------------------===//
930// Parser, printer and verify for dyn_groupprivate Clause
931//===----------------------------------------------------------------------===//
932
933static LogicalResult
934verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup,
935 FallbackModifierAttr fallback,
936 Value dynGroupprivateSize) {
937 if (!dynGroupprivateSize && (accessGroup || fallback))
938 return op->emitOpError("dyn_groupprivate modifiers require a size operand");
939
940 return success();
941}
942
943static ParseResult parseDynGroupprivateClause(
944 OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
945 FallbackModifierAttr &fallbackAttr,
946 std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
947 Type &sizeType) {
948
949 bool parsedAccessGroup = false;
950 bool parsedFallback = false;
951 bool parsedSize = false;
952
953 return parser.parseCommaSeparatedList([&]() -> ParseResult {
954 // Parse AccessGroupModifier.
955 if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
956 if (parsedAccessGroup)
957 return parser.emitError(parser.getCurrentLocation(),
958 "duplicate access group modifier");
959 accessGroupAttr = AccessGroupModifierAttr::get(
960 parser.getContext(), AccessGroupModifier::cgroup);
961 parsedAccessGroup = true;
962 return success();
963 }
964 // Parse FallbackModifier.
965 if (succeeded(parser.parseOptionalKeyword("fallback"))) {
966 if (parsedFallback)
967 return parser.emitError(parser.getCurrentLocation(),
968 "duplicate fallback modifier");
969 if (parser.parseLParen())
970 return parser.emitError(parser.getCurrentLocation(),
971 "expected '(' after 'fallback'");
972 llvm::StringRef fbKind;
973 if (parser.parseKeyword(&fbKind))
974 return parser.emitError(
975 parser.getCurrentLocation(),
976 "expected fallback modifier (abort/null/default_mem)");
977 std::optional<FallbackModifier> fbEnum;
978 if (fbKind == "abort")
979 fbEnum = FallbackModifier::abort;
980 else if (fbKind == "null")
981 fbEnum = FallbackModifier::null;
982 else if (fbKind == "default_mem")
983 fbEnum = FallbackModifier::default_mem;
984 else
985 return parser.emitError(parser.getCurrentLocation(),
986 "invalid fallback modifier '" + fbKind + "'");
987 fallbackAttr = FallbackModifierAttr::get(parser.getContext(), *fbEnum);
988 if (parser.parseRParen())
989 return parser.emitError(parser.getCurrentLocation(),
990 "expected ')' after fallback modifier");
991 parsedFallback = true;
992 return success();
993 }
994 // Parse size operand.
996 if (succeeded(parser.parseOperand(operand))) {
997 if (parsedSize)
998 return parser.emitError(parser.getCurrentLocation(),
999 "duplicate size operand");
1000 dynGroupprivateSize = operand;
1001 parsedSize = true;
1002 if (failed(parser.parseColon()) || failed(parser.parseType(sizeType)))
1003 return parser.emitError(parser.getCurrentLocation(),
1004 "expected ':' and type after size operand");
1005 return success();
1006 }
1007 return parser.emitError(parser.getCurrentLocation(),
1008 "expected dyn_groupprivate_size operand");
1009 });
1010}
1011
1013 AccessGroupModifierAttr modifierFirst,
1014 FallbackModifierAttr modifierSecond,
1015 Value dynGroupprivateSize,
1016 Type sizeType) {
1017
1018 bool needsComma = false;
1019
1020 if (modifierFirst) {
1021 printer << modifierFirst.getValue();
1022 needsComma = true;
1023 }
1024
1025 if (modifierSecond) {
1026 if (needsComma)
1027 printer << ", ";
1028 printer << "fallback(";
1029 printer << modifierSecond.getValue();
1030 printer << ")";
1031 needsComma = true;
1032 }
1033
1034 if (dynGroupprivateSize) {
1035 if (needsComma)
1036 printer << ", ";
1037 printer << dynGroupprivateSize << " : " << sizeType;
1038 }
1039}
1040
1041//===----------------------------------------------------------------------===//
1042// Parsers for operations including clauses that define entry block arguments.
1043//===----------------------------------------------------------------------===//
1044
1045namespace {
1046struct MapParseArgs {
1047 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1048 SmallVectorImpl<Type> &types;
1049 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1050 SmallVectorImpl<Type> &types)
1051 : vars(vars), types(types) {}
1052};
1053struct PrivateParseArgs {
1054 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1055 llvm::SmallVectorImpl<Type> &types;
1056 ArrayAttr &syms;
1057 UnitAttr &needsBarrier;
1058 DenseI64ArrayAttr *mapIndices;
1059 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1060 SmallVectorImpl<Type> &types, ArrayAttr &syms,
1061 UnitAttr &needsBarrier,
1062 DenseI64ArrayAttr *mapIndices = nullptr)
1063 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1064 mapIndices(mapIndices) {}
1065};
1066
1067struct ReductionParseArgs {
1068 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1069 SmallVectorImpl<Type> &types;
1070 DenseBoolArrayAttr &byref;
1071 ArrayAttr &syms;
1072 ReductionModifierAttr *modifier;
1073 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1074 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
1075 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
1076 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1077};
1078
1079struct AllRegionParseArgs {
1080 std::optional<MapParseArgs> hasDeviceAddrArgs;
1081 std::optional<MapParseArgs> hostEvalArgs;
1082 std::optional<ReductionParseArgs> inReductionArgs;
1083 std::optional<MapParseArgs> mapArgs;
1084 std::optional<PrivateParseArgs> privateArgs;
1085 std::optional<ReductionParseArgs> reductionArgs;
1086 std::optional<ReductionParseArgs> taskReductionArgs;
1087 std::optional<MapParseArgs> useDeviceAddrArgs;
1088 std::optional<MapParseArgs> useDevicePtrArgs;
1089};
1090} // namespace
1091
1092static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
1093 return "private_barrier";
1094}
1095
1096static ParseResult parseClauseWithRegionArgs(
1097 OpAsmParser &parser,
1099 SmallVectorImpl<Type> &types,
1100 SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
1101 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
1102 DenseBoolArrayAttr *byref = nullptr,
1103 ReductionModifierAttr *modifier = nullptr,
1104 UnitAttr *needsBarrier = nullptr) {
1106 SmallVector<int64_t> mapIndicesVec;
1107 SmallVector<bool> isByRefVec;
1108 unsigned regionArgOffset = regionPrivateArgs.size();
1109
1110 if (parser.parseLParen())
1111 return failure();
1112
1113 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
1114 StringRef enumStr;
1115 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
1116 parser.parseComma())
1117 return failure();
1118 std::optional<ReductionModifier> enumValue =
1119 symbolizeReductionModifier(enumStr);
1120 if (!enumValue.has_value())
1121 return failure();
1122 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
1123 if (!*modifier)
1124 return failure();
1125 }
1126
1127 if (parser.parseCommaSeparatedList([&]() {
1128 if (byref)
1129 isByRefVec.push_back(
1130 parser.parseOptionalKeyword("byref").succeeded());
1131
1132 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
1133 return failure();
1134
1135 if (parser.parseOperand(operands.emplace_back()) ||
1136 parser.parseArrow() ||
1137 parser.parseArgument(regionPrivateArgs.emplace_back()))
1138 return failure();
1139
1140 if (mapIndices) {
1141 if (parser.parseOptionalLSquare().succeeded()) {
1142 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
1143 parser.parseInteger(mapIndicesVec.emplace_back()) ||
1144 parser.parseRSquare())
1145 return failure();
1146 } else {
1147 mapIndicesVec.push_back(-1);
1148 }
1149 }
1150
1151 return success();
1152 }))
1153 return failure();
1154
1155 if (parser.parseColon())
1156 return failure();
1157
1158 if (parser.parseCommaSeparatedList([&]() {
1159 if (parser.parseType(types.emplace_back()))
1160 return failure();
1161
1162 return success();
1163 }))
1164 return failure();
1165
1166 if (operands.size() != types.size())
1167 return failure();
1168
1169 if (parser.parseRParen())
1170 return failure();
1171
1172 if (needsBarrier) {
1174 .succeeded())
1175 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
1176 }
1177
1178 auto *argsBegin = regionPrivateArgs.begin();
1179 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
1180 argsBegin + regionArgOffset + types.size());
1181 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1182 prv.type = type;
1183 }
1184
1185 if (symbols) {
1186 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
1187 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
1188 }
1189
1190 if (!mapIndicesVec.empty())
1191 *mapIndices =
1192 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
1193
1194 if (byref)
1195 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
1196
1197 return success();
1198}
1199
1200static ParseResult parseBlockArgClause(
1201 OpAsmParser &parser,
1203 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1204 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1205 if (!mapArgs)
1206 return failure();
1207
1208 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
1209 entryBlockArgs)))
1210 return failure();
1211 }
1212 return success();
1213}
1214
1215static ParseResult parseBlockArgClause(
1216 OpAsmParser &parser,
1218 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1219 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1220 if (!privateArgs)
1221 return failure();
1222
1223 if (failed(parseClauseWithRegionArgs(
1224 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1225 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1226 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
1227 return failure();
1228 }
1229 return success();
1230}
1231
1232static ParseResult parseBlockArgClause(
1233 OpAsmParser &parser,
1235 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1236 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1237 if (!reductionArgs)
1238 return failure();
1239 if (failed(parseClauseWithRegionArgs(
1240 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1241 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1242 reductionArgs->modifier)))
1243 return failure();
1244 }
1245 return success();
1246}
1247
1248static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1249 AllRegionParseArgs args) {
1251
1252 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1253 args.hasDeviceAddrArgs)))
1254 return parser.emitError(parser.getCurrentLocation())
1255 << "invalid `has_device_addr` format";
1256
1257 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1258 args.hostEvalArgs)))
1259 return parser.emitError(parser.getCurrentLocation())
1260 << "invalid `host_eval` format";
1261
1262 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1263 args.inReductionArgs)))
1264 return parser.emitError(parser.getCurrentLocation())
1265 << "invalid `in_reduction` format";
1266
1267 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1268 args.mapArgs)))
1269 return parser.emitError(parser.getCurrentLocation())
1270 << "invalid `map_entries` format";
1271
1272 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1273 args.privateArgs)))
1274 return parser.emitError(parser.getCurrentLocation())
1275 << "invalid `private` format";
1276
1277 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1278 args.reductionArgs)))
1279 return parser.emitError(parser.getCurrentLocation())
1280 << "invalid `reduction` format";
1281
1282 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1283 args.taskReductionArgs)))
1284 return parser.emitError(parser.getCurrentLocation())
1285 << "invalid `task_reduction` format";
1286
1287 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1288 args.useDeviceAddrArgs)))
1289 return parser.emitError(parser.getCurrentLocation())
1290 << "invalid `use_device_addr` format";
1291
1292 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1293 args.useDevicePtrArgs)))
1294 return parser.emitError(parser.getCurrentLocation())
1295 << "invalid `use_device_addr` format";
1296
1297 return parser.parseRegion(region, entryBlockArgs);
1298}
1299
1300// These parseXyz functions correspond to the custom<Xyz> definitions
1301// in the .td file(s).
1302static ParseResult parseTargetOpRegion(
1303 OpAsmParser &parser, Region &region,
1305 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1307 SmallVectorImpl<Type> &hostEvalTypes,
1309 SmallVectorImpl<Type> &inReductionTypes,
1310 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1312 SmallVectorImpl<Type> &mapTypes,
1314 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1315 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1316 AllRegionParseArgs args;
1317 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1318 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1319 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1320 inReductionByref, inReductionSyms);
1321 args.mapArgs.emplace(mapVars, mapTypes);
1322 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1323 privateNeedsBarrier, &privateMaps);
1324 return parseBlockArgRegion(parser, region, args);
1325}
1326
1328 OpAsmParser &parser, Region &region,
1330 SmallVectorImpl<Type> &inReductionTypes,
1331 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1333 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1334 UnitAttr &privateNeedsBarrier) {
1335 AllRegionParseArgs args;
1336 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1337 inReductionByref, inReductionSyms);
1338 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1339 privateNeedsBarrier);
1340 return parseBlockArgRegion(parser, region, args);
1341}
1342
1344 OpAsmParser &parser, Region &region,
1346 SmallVectorImpl<Type> &inReductionTypes,
1347 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1349 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1350 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1352 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1353 ArrayAttr &reductionSyms) {
1354 AllRegionParseArgs args;
1355 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1356 inReductionByref, inReductionSyms);
1357 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1358 privateNeedsBarrier);
1359 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1360 reductionSyms, &reductionMod);
1361 return parseBlockArgRegion(parser, region, args);
1362}
1363
1364static ParseResult parsePrivateRegion(
1365 OpAsmParser &parser, Region &region,
1367 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1368 UnitAttr &privateNeedsBarrier) {
1369 AllRegionParseArgs args;
1370 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1371 privateNeedsBarrier);
1372 return parseBlockArgRegion(parser, region, args);
1373}
1374
1376 OpAsmParser &parser, Region &region,
1378 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1379 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1381 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1382 ArrayAttr &reductionSyms) {
1383 AllRegionParseArgs args;
1384 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1385 privateNeedsBarrier);
1386 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1387 reductionSyms, &reductionMod);
1388 return parseBlockArgRegion(parser, region, args);
1389}
1390
1391static ParseResult parseTaskReductionRegion(
1392 OpAsmParser &parser, Region &region,
1394 SmallVectorImpl<Type> &taskReductionTypes,
1395 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1396 AllRegionParseArgs args;
1397 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1398 taskReductionByref, taskReductionSyms);
1399 return parseBlockArgRegion(parser, region, args);
1400}
1401
1403 OpAsmParser &parser, Region &region,
1405 SmallVectorImpl<Type> &useDeviceAddrTypes,
1407 SmallVectorImpl<Type> &useDevicePtrTypes) {
1408 AllRegionParseArgs args;
1409 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1410 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1411 return parseBlockArgRegion(parser, region, args);
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// Printers for operations including clauses that define entry block arguments.
1416//===----------------------------------------------------------------------===//
1417
1418namespace {
1419struct MapPrintArgs {
1420 ValueRange vars;
1421 TypeRange types;
1422 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1423};
1424struct PrivatePrintArgs {
1425 ValueRange vars;
1426 TypeRange types;
1427 ArrayAttr syms;
1428 UnitAttr needsBarrier;
1429 DenseI64ArrayAttr mapIndices;
1430 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1431 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1432 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1433 mapIndices(mapIndices) {}
1434};
1435struct ReductionPrintArgs {
1436 ValueRange vars;
1437 TypeRange types;
1438 DenseBoolArrayAttr byref;
1439 ArrayAttr syms;
1440 ReductionModifierAttr modifier;
1441 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1442 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1443 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1444};
1445struct AllRegionPrintArgs {
1446 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1447 std::optional<MapPrintArgs> hostEvalArgs;
1448 std::optional<ReductionPrintArgs> inReductionArgs;
1449 std::optional<MapPrintArgs> mapArgs;
1450 std::optional<PrivatePrintArgs> privateArgs;
1451 std::optional<ReductionPrintArgs> reductionArgs;
1452 std::optional<ReductionPrintArgs> taskReductionArgs;
1453 std::optional<MapPrintArgs> useDeviceAddrArgs;
1454 std::optional<MapPrintArgs> useDevicePtrArgs;
1455};
1456} // namespace
1457
1459 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1460 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1461 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1462 DenseBoolArrayAttr byref = nullptr,
1463 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1464 if (argsSubrange.empty())
1465 return;
1466
1467 p << clauseName << "(";
1468
1469 if (modifier)
1470 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1471
1472 if (!symbols) {
1473 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1474 symbols = ArrayAttr::get(ctx, values);
1475 }
1476
1477 if (!mapIndices) {
1478 llvm::SmallVector<int64_t> values(operands.size(), -1);
1479 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1480 }
1481
1482 if (!byref) {
1483 mlir::SmallVector<bool> values(operands.size(), false);
1484 byref = DenseBoolArrayAttr::get(ctx, values);
1485 }
1486
1487 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1488 mapIndices.asArrayRef(),
1489 byref.asArrayRef()),
1490 p, [&p](auto t) {
1491 auto [op, arg, sym, map, isByRef] = t;
1492 if (isByRef)
1493 p << "byref ";
1494 if (sym)
1495 p << sym << " ";
1496
1497 p << op << " -> " << arg;
1498
1499 if (map != -1)
1500 p << " [map_idx=" << map << "]";
1501 });
1502 p << " : ";
1503 llvm::interleaveComma(types, p);
1504 p << ") ";
1505
1506 if (needsBarrier)
1507 p << getPrivateNeedsBarrierSpelling() << " ";
1508}
1509
1511 StringRef clauseName, ValueRange argsSubrange,
1512 std::optional<MapPrintArgs> mapArgs) {
1513 if (mapArgs)
1514 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1515 mapArgs->types);
1516}
1517
1519 StringRef clauseName, ValueRange argsSubrange,
1520 std::optional<PrivatePrintArgs> privateArgs) {
1521 if (privateArgs)
1523 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1524 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1525 /*modifier=*/nullptr, privateArgs->needsBarrier);
1526}
1527
1528static void
1529printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1530 ValueRange argsSubrange,
1531 std::optional<ReductionPrintArgs> reductionArgs) {
1532 if (reductionArgs)
1533 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1534 reductionArgs->vars, reductionArgs->types,
1535 reductionArgs->syms, /*mapIndices=*/nullptr,
1536 reductionArgs->byref, reductionArgs->modifier);
1537}
1538
1540 const AllRegionPrintArgs &args) {
1541 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1542 MLIRContext *ctx = op->getContext();
1543
1544 printBlockArgClause(p, ctx, "has_device_addr",
1545 iface.getHasDeviceAddrBlockArgs(),
1546 args.hasDeviceAddrArgs);
1547 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1548 args.hostEvalArgs);
1549 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1550 args.inReductionArgs);
1551 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1552 args.mapArgs);
1553 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1554 args.privateArgs);
1555 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1556 args.reductionArgs);
1557 printBlockArgClause(p, ctx, "task_reduction",
1558 iface.getTaskReductionBlockArgs(),
1559 args.taskReductionArgs);
1560 printBlockArgClause(p, ctx, "use_device_addr",
1561 iface.getUseDeviceAddrBlockArgs(),
1562 args.useDeviceAddrArgs);
1563 printBlockArgClause(p, ctx, "use_device_ptr",
1564 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1565
1566 p.printRegion(region, /*printEntryBlockArgs=*/false);
1567}
1568
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1572 OpAsmPrinter &p, Operation *op, Region &region,
1573 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1574 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1575 ValueRange inReductionVars, TypeRange inReductionTypes,
1576 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1577 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1578 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1579 DenseI64ArrayAttr privateMaps) {
1580 AllRegionPrintArgs args;
1581 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1582 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1583 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1584 inReductionByref, inReductionSyms);
1585 args.mapArgs.emplace(mapVars, mapTypes);
1586 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1587 privateNeedsBarrier, privateMaps);
1588 printBlockArgRegion(p, op, region, args);
1589}
1590
1592 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1593 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1594 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1595 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1596 AllRegionPrintArgs args;
1597 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1598 inReductionByref, inReductionSyms);
1599 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1600 privateNeedsBarrier,
1601 /*mapIndices=*/nullptr);
1602 printBlockArgRegion(p, op, region, args);
1603}
1604
1606 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1607 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1608 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1609 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1610 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1611 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1612 ArrayAttr reductionSyms) {
1613 AllRegionPrintArgs args;
1614 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1615 inReductionByref, inReductionSyms);
1616 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1617 privateNeedsBarrier,
1618 /*mapIndices=*/nullptr);
1619 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1620 reductionSyms, reductionMod);
1621 printBlockArgRegion(p, op, region, args);
1622}
1623
1625 ValueRange privateVars, TypeRange privateTypes,
1626 ArrayAttr privateSyms,
1627 UnitAttr privateNeedsBarrier) {
1628 AllRegionPrintArgs args;
1629 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1630 privateNeedsBarrier,
1631 /*mapIndices=*/nullptr);
1632 printBlockArgRegion(p, op, region, args);
1633}
1634
1636 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1637 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1638 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1639 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1640 ArrayAttr reductionSyms) {
1641 AllRegionPrintArgs args;
1642 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1643 privateNeedsBarrier,
1644 /*mapIndices=*/nullptr);
1645 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1646 reductionSyms, reductionMod);
1647 printBlockArgRegion(p, op, region, args);
1648}
1649
1651 Region &region,
1652 ValueRange taskReductionVars,
1653 TypeRange taskReductionTypes,
1654 DenseBoolArrayAttr taskReductionByref,
1655 ArrayAttr taskReductionSyms) {
1656 AllRegionPrintArgs args;
1657 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1658 taskReductionByref, taskReductionSyms);
1659 printBlockArgRegion(p, op, region, args);
1660}
1661
1663 Region &region,
1664 ValueRange useDeviceAddrVars,
1665 TypeRange useDeviceAddrTypes,
1666 ValueRange useDevicePtrVars,
1667 TypeRange useDevicePtrTypes) {
1668 AllRegionPrintArgs args;
1669 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1670 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1671 printBlockArgRegion(p, op, region, args);
1672}
1673
1674template <typename ParsePrefixFn>
1675static ParseResult parseSplitIteratedList(
1676 OpAsmParser &parser,
1678 SmallVectorImpl<Type> &iteratedTypes,
1680 SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {
1681
1682 return parser.parseCommaSeparatedList([&]() -> ParseResult {
1683 if (failed(parsePrefix()))
1684 return failure();
1685
1687 Type ty;
1688 if (parser.parseOperand(v) || parser.parseColonType(ty))
1689 return failure();
1690
1691 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1692 iteratedVars.push_back(v);
1693 iteratedTypes.push_back(ty);
1694 } else {
1695 plainVars.push_back(v);
1696 plainTypes.push_back(ty);
1697 }
1698 return success();
1699 });
1700}
1701
1702template <typename PrintPrefixFn>
1704 TypeRange iteratedTypes,
1705 ValueRange plainVars, TypeRange plainTypes,
1706 PrintPrefixFn &&printPrefixForPlain,
1707 PrintPrefixFn &&printPrefixForIterated) {
1708
1709 bool first = true;
1710 auto emit = [&](Value v, Type t, auto &&printPrefix) {
1711 if (!first)
1712 p << ", ";
1713 printPrefix(v, t);
1714 p << v << " : " << t;
1715 first = false;
1716 };
1717
1718 for (unsigned i = 0; i < iteratedVars.size(); ++i)
1719 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1720 for (unsigned i = 0; i < plainVars.size(); ++i)
1721 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1722}
1723
1724/// Verifies Reduction Clause
1725static LogicalResult
1726verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1727 OperandRange reductionVars,
1728 std::optional<ArrayRef<bool>> reductionByref) {
1729 if (!reductionVars.empty()) {
1730 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1731 return op->emitOpError()
1732 << "expected as many reduction symbol references "
1733 "as reduction variables";
1734 if (reductionByref && reductionByref->size() != reductionVars.size())
1735 return op->emitError() << "expected as many reduction variable by "
1736 "reference attributes as reduction variables";
1737 } else {
1738 if (reductionSyms)
1739 return op->emitOpError() << "unexpected reduction symbol references";
1740 return success();
1741 }
1742
1743 // TODO: The followings should be done in
1744 // SymbolUserOpInterface::verifySymbolUses.
1745 DenseSet<Value> accumulators;
1746 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1747 Value accum = std::get<0>(args);
1748
1749 if (!accumulators.insert(accum).second)
1750 return op->emitOpError() << "accumulator variable used more than once";
1751
1752 Type varType = accum.getType();
1753 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1754 auto decl =
1756 if (!decl)
1757 return op->emitOpError() << "expected symbol reference " << symbolRef
1758 << " to point to a reduction declaration";
1759
1760 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1761 return op->emitOpError()
1762 << "expected accumulator (" << varType
1763 << ") to be the same type as reduction declaration ("
1764 << decl.getAccumulatorType() << ")";
1765 }
1766
1767 return success();
1768}
1769
1770//===----------------------------------------------------------------------===//
1771// Parser, printer and verifier for Copyprivate
1772//===----------------------------------------------------------------------===//
1773
1774/// copyprivate-entry-list ::= copyprivate-entry
1775/// | copyprivate-entry-list `,` copyprivate-entry
1776/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1777static ParseResult parseCopyprivate(
1778 OpAsmParser &parser,
1780 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1782 if (failed(parser.parseCommaSeparatedList([&]() {
1783 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1784 parser.parseArrow() ||
1785 parser.parseAttribute(symsVec.emplace_back()) ||
1786 parser.parseColonType(copyprivateTypes.emplace_back()))
1787 return failure();
1788 return success();
1789 })))
1790 return failure();
1791 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1792 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1793 return success();
1794}
1795
1796/// Print Copyprivate clause
1798 OperandRange copyprivateVars,
1799 TypeRange copyprivateTypes,
1800 std::optional<ArrayAttr> copyprivateSyms) {
1801 if (!copyprivateSyms.has_value())
1802 return;
1803 llvm::interleaveComma(
1804 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1805 [&](const auto &args) {
1806 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1807 << std::get<2>(args);
1808 });
1809}
1810
1811/// Verifies CopyPrivate Clause
1812static LogicalResult
1814 std::optional<ArrayAttr> copyprivateSyms) {
1815 size_t copyprivateSymsSize =
1816 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1817 if (copyprivateSymsSize != copyprivateVars.size())
1818 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1819 << copyprivateVars.size()
1820 << ") and functions (= " << copyprivateSymsSize
1821 << "), both must be equal";
1822 if (!copyprivateSyms.has_value())
1823 return success();
1824
1825 for (auto copyprivateVarAndSym :
1826 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1827 auto symbolRef =
1828 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1829 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1830 funcOp;
1831 if (mlir::func::FuncOp mlirFuncOp =
1833 symbolRef))
1834 funcOp = mlirFuncOp;
1835 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1837 op, symbolRef))
1838 funcOp = llvmFuncOp;
1839
1840 auto getNumArguments = [&] {
1841 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1842 };
1843
1844 auto getArgumentType = [&](unsigned i) {
1845 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1846 *funcOp);
1847 };
1848
1849 if (!funcOp)
1850 return op->emitOpError() << "expected symbol reference " << symbolRef
1851 << " to point to a copy function";
1852
1853 if (getNumArguments() != 2)
1854 return op->emitOpError()
1855 << "expected copy function " << symbolRef << " to have 2 operands";
1856
1857 Type argTy = getArgumentType(0);
1858 if (argTy != getArgumentType(1))
1859 return op->emitOpError() << "expected copy function " << symbolRef
1860 << " arguments to have the same type";
1861
1862 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1863 if (argTy != varType)
1864 return op->emitOpError()
1865 << "expected copy function arguments' type (" << argTy
1866 << ") to be the same as copyprivate variable's type (" << varType
1867 << ")";
1868 }
1869
1870 return success();
1871}
1872
1873//===----------------------------------------------------------------------===//
1874// Parser, printer and verifier for DependVarList
1875//===----------------------------------------------------------------------===//
1876
1877/// depend-entry-list ::= depend-entry
1878/// | depend-entry-list `,` depend-entry
1879/// depend-entry ::= depend-kind `->` ssa-id `:` type
1880/// | depend-kind `->` ssa-id `:` iterated-type
1881static ParseResult parseDependVarList(
1882 OpAsmParser &parser,
1884 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
1886 SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
1889 if (failed(parser.parseCommaSeparatedList([&]() {
1890 StringRef keyword;
1891 OpAsmParser::UnresolvedOperand operand;
1892 Type ty;
1893 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1894 parser.parseOperand(operand) || parser.parseColonType(ty))
1895 return failure();
1896 std::optional<ClauseTaskDepend> keywordDepend =
1897 symbolizeClauseTaskDepend(keyword);
1898 if (!keywordDepend)
1899 return failure();
1900 auto kindAttr =
1901 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1902 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1903 iteratedVars.push_back(operand);
1904 iteratedTypes.push_back(ty);
1905 iterKindsVec.push_back(kindAttr);
1906 } else {
1907 dependVars.push_back(operand);
1908 dependTypes.push_back(ty);
1909 kindsVec.push_back(kindAttr);
1910 }
1911 return success();
1912 })))
1913 return failure();
1914 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1915 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1916 SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
1917 iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
1918 return success();
1919}
1920
1921/// Print Depend clause
1923 OperandRange dependVars, TypeRange dependTypes,
1924 std::optional<ArrayAttr> dependKinds,
1925 OperandRange iteratedVars,
1926 TypeRange iteratedTypes,
1927 std::optional<ArrayAttr> iteratedKinds) {
1928 bool first = true;
1929 auto printEntries = [&](OperandRange vars, TypeRange types,
1930 std::optional<ArrayAttr> kinds) {
1931 for (unsigned i = 0, e = vars.size(); i < e; ++i) {
1932 if (!first)
1933 p << ", ";
1934 p << stringifyClauseTaskDepend(
1935 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1936 .getValue())
1937 << " -> " << vars[i] << " : " << types[i];
1938 first = false;
1939 }
1940 };
1941 printEntries(dependVars, dependTypes, dependKinds);
1942 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1943}
1944
1945/// Verifies Depend clause
1946static LogicalResult verifyDependVarList(Operation *op,
1947 std::optional<ArrayAttr> dependKinds,
1948 OperandRange dependVars,
1949 std::optional<ArrayAttr> iteratedKinds,
1950 OperandRange iteratedVars) {
1951 if (!dependVars.empty()) {
1952 if (!dependKinds || dependKinds->size() != dependVars.size())
1953 return op->emitOpError() << "expected as many depend values"
1954 " as depend variables";
1955 } else {
1956 if (dependKinds && !dependKinds->empty())
1957 return op->emitOpError() << "unexpected depend values";
1958 }
1959
1960 if (!iteratedVars.empty()) {
1961 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1962 return op->emitOpError() << "expected as many depend iterated values"
1963 " as depend iterated variables";
1964 } else {
1965 if (iteratedKinds && !iteratedKinds->empty())
1966 return op->emitOpError() << "unexpected depend iterated values";
1967 }
1968
1969 return success();
1970}
1971
1972//===----------------------------------------------------------------------===//
1973// Parser, printer and verifier for Synchronization Hint (2.17.12)
1974//===----------------------------------------------------------------------===//
1975
1976/// Parses a Synchronization Hint clause. The value of hint is an integer
1977/// which is a combination of different hints from `omp_sync_hint_t`.
1978///
1979/// hint-clause = `hint` `(` hint-value `)`
1980static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1981 IntegerAttr &hintAttr) {
1982 StringRef hintKeyword;
1983 int64_t hint = 0;
1984 if (succeeded(parser.parseOptionalKeyword("none"))) {
1985 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1986 return success();
1987 }
1988 auto parseKeyword = [&]() -> ParseResult {
1989 if (failed(parser.parseKeyword(&hintKeyword)))
1990 return failure();
1991 if (hintKeyword == "uncontended")
1992 hint |= 1;
1993 else if (hintKeyword == "contended")
1994 hint |= 2;
1995 else if (hintKeyword == "nonspeculative")
1996 hint |= 4;
1997 else if (hintKeyword == "speculative")
1998 hint |= 8;
1999 else
2000 return parser.emitError(parser.getCurrentLocation())
2001 << hintKeyword << " is not a valid hint";
2002 return success();
2003 };
2004 if (parser.parseCommaSeparatedList(parseKeyword))
2005 return failure();
2006 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
2007 return success();
2008}
2009
2010/// Prints a Synchronization Hint clause
2012 IntegerAttr hintAttr) {
2013 int64_t hint = hintAttr.getInt();
2014
2015 if (hint == 0) {
2016 p << "none";
2017 return;
2018 }
2019
2020 // Helper function to get n-th bit from the right end of `value`
2021 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
2022
2023 bool uncontended = bitn(hint, 0);
2024 bool contended = bitn(hint, 1);
2025 bool nonspeculative = bitn(hint, 2);
2026 bool speculative = bitn(hint, 3);
2027
2029 if (uncontended)
2030 hints.push_back("uncontended");
2031 if (contended)
2032 hints.push_back("contended");
2033 if (nonspeculative)
2034 hints.push_back("nonspeculative");
2035 if (speculative)
2036 hints.push_back("speculative");
2037
2038 llvm::interleaveComma(hints, p);
2039}
2040
2041/// Verifies a synchronization hint clause
2042static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
2043
2044 // Helper function to get n-th bit from the right end of `value`
2045 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
2046
2047 bool uncontended = bitn(hint, 0);
2048 bool contended = bitn(hint, 1);
2049 bool nonspeculative = bitn(hint, 2);
2050 bool speculative = bitn(hint, 3);
2051
2052 if (uncontended && contended)
2053 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
2054 "omp_sync_hint_contended cannot be combined";
2055 if (nonspeculative && speculative)
2056 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
2057 "omp_sync_hint_speculative cannot be combined.";
2058 return success();
2059}
2060
2061//===----------------------------------------------------------------------===//
2062// Parser, printer and verifier for Target
2063//===----------------------------------------------------------------------===//
2064
2065// Helper function to get bitwise AND of `value` and 'flag' then return it as a
2066// boolean
2067static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
2068 return (value & flag) == flag;
2069}
2070
2071/// Parses a map_entries map type from a string format back into its numeric
2072/// value.
2073///
2074/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
2075/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
2076static ParseResult parseMapClause(OpAsmParser &parser,
2077 ClauseMapFlagsAttr &mapType) {
2078 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
2079 // This simply verifies the correct keyword is read in, the
2080 // keyword itself is stored inside of the operation
2081 auto parseTypeAndMod = [&]() -> ParseResult {
2082 StringRef mapTypeMod;
2083 if (parser.parseKeyword(&mapTypeMod))
2084 return failure();
2085
2086 if (mapTypeMod == "always")
2087 mapTypeBits |= ClauseMapFlags::always;
2088
2089 if (mapTypeMod == "implicit")
2090 mapTypeBits |= ClauseMapFlags::implicit;
2091
2092 if (mapTypeMod == "ompx_hold")
2093 mapTypeBits |= ClauseMapFlags::ompx_hold;
2094
2095 if (mapTypeMod == "close")
2096 mapTypeBits |= ClauseMapFlags::close;
2097
2098 if (mapTypeMod == "present")
2099 mapTypeBits |= ClauseMapFlags::present;
2100
2101 if (mapTypeMod == "to")
2102 mapTypeBits |= ClauseMapFlags::to;
2103
2104 if (mapTypeMod == "from")
2105 mapTypeBits |= ClauseMapFlags::from;
2106
2107 if (mapTypeMod == "tofrom")
2108 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
2109
2110 if (mapTypeMod == "delete")
2111 mapTypeBits |= ClauseMapFlags::del;
2112
2113 if (mapTypeMod == "storage")
2114 mapTypeBits |= ClauseMapFlags::storage;
2115
2116 if (mapTypeMod == "return_param")
2117 mapTypeBits |= ClauseMapFlags::return_param;
2118
2119 if (mapTypeMod == "private")
2120 mapTypeBits |= ClauseMapFlags::priv;
2121
2122 if (mapTypeMod == "literal")
2123 mapTypeBits |= ClauseMapFlags::literal;
2124
2125 if (mapTypeMod == "attach")
2126 mapTypeBits |= ClauseMapFlags::attach;
2127
2128 if (mapTypeMod == "attach_always")
2129 mapTypeBits |= ClauseMapFlags::attach_always;
2130
2131 if (mapTypeMod == "attach_never")
2132 mapTypeBits |= ClauseMapFlags::attach_never;
2133
2134 if (mapTypeMod == "attach_auto")
2135 mapTypeBits |= ClauseMapFlags::attach_auto;
2136
2137 if (mapTypeMod == "ref_ptr")
2138 mapTypeBits |= ClauseMapFlags::ref_ptr;
2139
2140 if (mapTypeMod == "ref_ptee")
2141 mapTypeBits |= ClauseMapFlags::ref_ptee;
2142
2143 if (mapTypeMod == "ref_ptr_ptee")
2144 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
2145
2146 if (mapTypeMod == "is_device_ptr")
2147 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2148
2149 return success();
2150 };
2151
2152 if (parser.parseCommaSeparatedList(parseTypeAndMod))
2153 return failure();
2154
2155 mapType =
2156 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
2157
2158 return success();
2159}
2160
2161/// Prints a map_entries map type from its numeric value out into its string
2162/// format.
2163static void printMapClause(OpAsmPrinter &p, Operation *op,
2164 ClauseMapFlagsAttr mapType) {
2166 ClauseMapFlags mapFlags = mapType.getValue();
2167
2168 // handling of always, close, present placed at the beginning of the string
2169 // to aid readability
2170 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
2171 mapTypeStrs.push_back("always");
2172 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
2173 mapTypeStrs.push_back("implicit");
2174 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
2175 mapTypeStrs.push_back("ompx_hold");
2176 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
2177 mapTypeStrs.push_back("close");
2178 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
2179 mapTypeStrs.push_back("present");
2180
2181 // special handling of to/from/tofrom/delete and release/alloc, release +
2182 // alloc are the abscense of one of the other flags, whereas tofrom requires
2183 // both the to and from flag to be set.
2184 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
2185 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
2186
2187 if (to && from)
2188 mapTypeStrs.push_back("tofrom");
2189 else if (from)
2190 mapTypeStrs.push_back("from");
2191 else if (to)
2192 mapTypeStrs.push_back("to");
2193
2194 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
2195 mapTypeStrs.push_back("delete");
2196 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
2197 mapTypeStrs.push_back("return_param");
2198 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
2199 mapTypeStrs.push_back("storage");
2200 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
2201 mapTypeStrs.push_back("private");
2202 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
2203 mapTypeStrs.push_back("literal");
2204 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
2205 mapTypeStrs.push_back("attach");
2206 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
2207 mapTypeStrs.push_back("attach_always");
2208 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
2209 mapTypeStrs.push_back("attach_never");
2210 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
2211 mapTypeStrs.push_back("attach_auto");
2212 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
2213 mapTypeStrs.push_back("ref_ptr");
2214 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
2215 mapTypeStrs.push_back("ref_ptee");
2216 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
2217 mapTypeStrs.push_back("ref_ptr_ptee");
2218 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
2219 mapTypeStrs.push_back("is_device_ptr");
2220 if (mapFlags == ClauseMapFlags::none)
2221 mapTypeStrs.push_back("none");
2222
2223 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2224 p << mapTypeStrs[i];
2225 if (i + 1 < mapTypeStrs.size()) {
2226 p << ", ";
2227 }
2228 }
2229}
2230
2231static ParseResult parseMembersIndex(OpAsmParser &parser,
2232 ArrayAttr &membersIdx) {
2233 SmallVector<Attribute> values, memberIdxs;
2234
2235 auto parseIndices = [&]() -> ParseResult {
2236 int64_t value;
2237 if (parser.parseInteger(value))
2238 return failure();
2239 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
2240 APInt(64, value, /*isSigned=*/false)));
2241 return success();
2242 };
2243
2244 do {
2245 if (failed(parser.parseLSquare()))
2246 return failure();
2247
2248 if (parser.parseCommaSeparatedList(parseIndices))
2249 return failure();
2250
2251 if (failed(parser.parseRSquare()))
2252 return failure();
2253
2254 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
2255 values.clear();
2256 } while (succeeded(parser.parseOptionalComma()));
2257
2258 if (!memberIdxs.empty())
2259 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
2260
2261 return success();
2262}
2263
2264static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
2265 ArrayAttr membersIdx) {
2266 if (!membersIdx)
2267 return;
2268
2269 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
2270 p << "[";
2271 auto memberIdx = cast<ArrayAttr>(v);
2272 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2273 p << cast<IntegerAttr>(v2).getInt();
2274 });
2275 p << "]";
2276 });
2277}
2278
2280 VariableCaptureKindAttr mapCaptureType) {
2281 std::string typeCapStr;
2282 llvm::raw_string_ostream typeCap(typeCapStr);
2283 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2284 typeCap << "ByRef";
2285 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2286 typeCap << "ByCopy";
2287 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2288 typeCap << "VLAType";
2289 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2290 typeCap << "This";
2291 p << typeCapStr;
2292}
2293
2294static ParseResult parseCaptureType(OpAsmParser &parser,
2295 VariableCaptureKindAttr &mapCaptureType) {
2296 StringRef mapCaptureKey;
2297 if (parser.parseKeyword(&mapCaptureKey))
2298 return failure();
2299
2300 if (mapCaptureKey == "This")
2301 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2302 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2303 if (mapCaptureKey == "ByRef")
2304 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2305 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2306 if (mapCaptureKey == "ByCopy")
2307 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2308 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2309 if (mapCaptureKey == "VLAType")
2310 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2311 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2312
2313 return success();
2314}
2315
2316static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
2319
2320 for (auto mapOp : mapVars) {
2321 if (!mapOp.getDefiningOp())
2322 return emitError(op->getLoc(), "missing map operation");
2323
2324 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2325 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2326
2327 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2328 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2329 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2330
2331 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2332 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2333 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2334
2335 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2336 return emitError(op->getLoc(),
2337 "to, from, tofrom and alloc map types are permitted");
2338
2339 if (isa<TargetEnterDataOp>(op) && (from || del))
2340 return emitError(op->getLoc(), "to and alloc map types are permitted");
2341
2342 if (isa<TargetExitDataOp>(op) && to)
2343 return emitError(op->getLoc(),
2344 "from, release and delete map types are permitted");
2345
2346 if (isa<TargetUpdateOp>(op)) {
2347 if (del) {
2348 return emitError(op->getLoc(),
2349 "at least one of to or from map types must be "
2350 "specified, other map types are not permitted");
2351 }
2352
2353 if (!to && !from) {
2354 return emitError(op->getLoc(),
2355 "at least one of to or from map types must be "
2356 "specified, other map types are not permitted");
2357 }
2358
2359 auto updateVar = mapInfoOp.getVarPtr();
2360
2361 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2362 (from && updateToVars.contains(updateVar))) {
2363 return emitError(
2364 op->getLoc(),
2365 "either to or from map types can be specified, not both");
2366 }
2367
2368 if (always || close || implicit) {
2369 return emitError(
2370 op->getLoc(),
2371 "present, mapper and iterator map type modifiers are permitted");
2372 }
2373
2374 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2375 }
2376 } else if (!isa<DeclareMapperInfoOp>(op)) {
2377 return emitError(op->getLoc(),
2378 "map argument is not a map entry operation");
2379 }
2380 }
2381
2382 return success();
2383}
2384
// Forward declaration so that TargetOp::verify below can call
// verifyPrivateVarList. The function is `static`, so it must be defined
// elsewhere in this translation unit.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op);
2387
2388static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2389 std::optional<DenseI64ArrayAttr> privateMapIndices =
2390 targetOp.getPrivateMapsAttr();
2391
2392 // None of the private operands are mapped.
2393 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2394 return success();
2395
2396 OperandRange privateVars = targetOp.getPrivateVars();
2397
2398 if (privateMapIndices.value().size() !=
2399 static_cast<int64_t>(privateVars.size()))
2400 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2401 "`private_maps` attribute mismatch");
2402
2403 return success();
2404}
2405
2406//===----------------------------------------------------------------------===//
2407// MapInfoOp
2408//===----------------------------------------------------------------------===//
2409
2410static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2411 StringRef clauseName,
2412 OperandRange vars) {
2413 for (Value var : vars)
2414 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2415 return op->emitOpError()
2416 << "'" << clauseName
2417 << "' arguments must be defined by 'omp.map.info' ops";
2418 return success();
2419}
2420
2421LogicalResult MapInfoOp::verify() {
2422 if (getMapperId() &&
2424 *this, getMapperIdAttr())) {
2425 return emitError("invalid mapper id");
2426 }
2427
2428 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2429 return failure();
2430
2431 return success();
2432}
2433
2434//===----------------------------------------------------------------------===//
2435// TargetDataOp
2436//===----------------------------------------------------------------------===//
2437
/// Builds a TargetDataOp from the grouped clause operands in `clauses`. The
/// device, if-expression, map, use_device_addr and use_device_ptr operands
/// are forwarded to the builder overload that takes them individually.
void TargetDataOp::build(OpBuilder &builder, OperationState &state,
                         const TargetDataOperands &clauses) {
  TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
                      clauses.mapVars, clauses.useDeviceAddrVars,
                      clauses.useDevicePtrVars);
}
2444
2445LogicalResult TargetDataOp::verify() {
2446 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2447 getUseDeviceAddrVars().empty()) {
2448 return ::emitError(this->getLoc(),
2449 "At least one of map, use_device_ptr_vars, or "
2450 "use_device_addr_vars operand must be present");
2451 }
2452
2453 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2454 getUseDevicePtrVars())))
2455 return failure();
2456
2457 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2458 getUseDeviceAddrVars())))
2459 return failure();
2460
2461 return verifyMapClause(*this, getMapVars());
2462}
2463
2464//===----------------------------------------------------------------------===//
2465// TargetEnterDataOp
2466//===----------------------------------------------------------------------===//
2467
/// Builds a TargetEnterDataOp from `clauses`. The plain and iterated depend
/// kind lists are wrapped into array attributes with makeArrayAttr. The
/// remaining clause operands (depend vars, device, if, map vars, nowait) are
/// forwarded unchanged.
void TargetEnterDataOp::build(
    OpBuilder &builder, OperationState &state,
    const TargetEnterExitUpdateDataOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TargetEnterDataOp::build(
      builder, state, makeArrayAttr(ctx, clauses.dependKinds),
      clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
      clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
      clauses.nowait);
}
2478
2479LogicalResult TargetEnterDataOp::verify() {
2480 LogicalResult verifyDependVars =
2481 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2482 getDependIteratedKinds(), getDependIterated());
2483 return failed(verifyDependVars) ? verifyDependVars
2484 : verifyMapClause(*this, getMapVars());
2485}
2486
2487//===----------------------------------------------------------------------===//
2488// TargetExitDataOp
2489//===----------------------------------------------------------------------===//
2490
/// Builds a TargetExitDataOp from `clauses`. The plain and iterated depend
/// kind lists are wrapped into array attributes with makeArrayAttr. The
/// remaining clause operands (depend vars, device, if, map vars, nowait) are
/// forwarded unchanged.
void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
                             const TargetEnterExitUpdateDataOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TargetExitDataOp::build(
      builder, state, makeArrayAttr(ctx, clauses.dependKinds),
      clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
      clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
      clauses.nowait);
}
2500
2501LogicalResult TargetExitDataOp::verify() {
2502 LogicalResult verifyDependVars =
2503 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2504 getDependIteratedKinds(), getDependIterated());
2505 return failed(verifyDependVars) ? verifyDependVars
2506 : verifyMapClause(*this, getMapVars());
2507}
2508
2509//===----------------------------------------------------------------------===//
2510// TargetUpdateOp
2511//===----------------------------------------------------------------------===//
2512
/// Builds a TargetUpdateOp from `clauses`. The plain and iterated depend kind
/// lists are wrapped into array attributes with makeArrayAttr. The remaining
/// clause operands (depend vars, device, if, map vars, nowait) are forwarded
/// unchanged.
void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
                           const TargetEnterExitUpdateDataOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
                        clauses.dependVars,
                        makeArrayAttr(ctx, clauses.dependIteratedKinds),
                        clauses.dependIterated, clauses.device, clauses.ifExpr,
                        clauses.mapVars, clauses.nowait);
}
2522
2523LogicalResult TargetUpdateOp::verify() {
2524 LogicalResult verifyDependVars =
2525 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2526 getDependIteratedKinds(), getDependIterated());
2527 return failed(verifyDependVars) ? verifyDependVars
2528 : verifyMapClause(*this, getMapVars());
2529}
2530
2531//===----------------------------------------------------------------------===//
2532// TargetOp
2533//===----------------------------------------------------------------------===//
2534
/// Builds a TargetOp from `clauses`.
///
/// Clauses not yet stored on the op are passed as empty/null: allocate,
/// allocator, in_reduction vars/byref/syms, and private_maps. Depend kinds
/// and private symbols are wrapped into array attributes with makeArrayAttr.
/// Everything else is forwarded unchanged.
void TargetOp::build(OpBuilder &builder, OperationState &state,
                     const TargetOperands &clauses) {
  MLIRContext *ctx = builder.getContext();
  // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
  // inReductionByref, inReductionSyms.
  TargetOp::build(
      builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
      makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
      makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
      clauses.device, clauses.dynGroupprivateAccessGroup,
      clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
      clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
      /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
      /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
      clauses.nowait, clauses.privateVars,
      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
      clauses.threadLimitVars,
      /*private_maps=*/nullptr);
}
2554
2555LogicalResult TargetOp::verify() {
2556 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
2557 getDependIteratedKinds(),
2558 getDependIterated())))
2559 return failure();
2560
2561 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2562 getHasDeviceAddrVars())))
2563 return failure();
2564
2565 if (failed(verifyMapClause(*this, getMapVars())))
2566 return failure();
2567
2569 *this, getDynGroupprivateAccessGroupAttr(),
2570 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
2571 return failure();
2572
2573 if (failed(verifyPrivateVarList(*this)))
2574 return failure();
2575
2576 return verifyPrivateVarsMapping(*this);
2577}
2578
2579LogicalResult TargetOp::verifyRegions() {
2580 auto teamsOps = getOps<TeamsOp>();
2581 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2582 return emitError("target containing multiple 'omp.teams' nested ops");
2583
2584 // Check that host_eval values are only used in legal ways.
2585 bool hostEvalTripCount;
2586 Operation *capturedOp = getInnermostCapturedOmpOp();
2587 TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
2588 for (Value hostEvalArg :
2589 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2590 for (Operation *user : hostEvalArg.getUsers()) {
2591 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2592 // Check if used in num_teams_lower or any of num_teams_upper_vars
2593 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2594 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2595 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2596 continue;
2597
2598 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2599 "and 'thread_limit' in 'omp.teams'";
2600 }
2601 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2602 if (execMode == TargetExecMode::spmd &&
2603 parallelOp->isAncestor(capturedOp) &&
2604 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2605 continue;
2606
2607 return emitOpError()
2608 << "host_eval argument only legal as 'num_threads' in "
2609 "'omp.parallel' when representing target SPMD";
2610 }
2611 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2612 if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
2613 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2614 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2615 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2616 continue;
2617
2618 return emitOpError() << "host_eval argument only legal as loop bounds "
2619 "and steps in 'omp.loop_nest' when trip count "
2620 "must be evaluated in the host";
2621 }
2622
2623 return emitOpError() << "host_eval argument illegal use in '"
2624 << user->getName() << "' operation";
2625 }
2626 }
2627 return success();
2628}
2629
2630static Operation *
2631findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2632 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2633 assert(rootOp && "expected valid operation");
2634
2635 Dialect *ompDialect = rootOp->getDialect();
2636 Operation *capturedOp = nullptr;
2637 DominanceInfo domInfo;
2638
2639 // Process in pre-order to check operations from outermost to innermost,
2640 // ensuring we only enter the region of an operation if it meets the criteria
2641 // for being captured. We stop the exploration of nested operations as soon as
2642 // we process a region holding no operations to be captured.
2643 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2644 if (op == rootOp)
2645 return WalkResult::advance();
2646
2647 // Ignore operations of other dialects or omp operations with no regions,
2648 // because these will only be checked if they are siblings of an omp
2649 // operation that can potentially be captured.
2650 bool isOmpDialect = op->getDialect() == ompDialect;
2651 bool hasRegions = op->getNumRegions() > 0;
2652 if (!isOmpDialect || !hasRegions)
2653 return WalkResult::skip();
2654
2655 // This operation cannot be captured if it can be executed more than once
2656 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2657 // be executed before all exits of the region (i.e. it doesn't dominate all
2658 // blocks with no successors reachable from the entry block).
2659 if (checkSingleMandatoryExec) {
2660 Region *parentRegion = op->getParentRegion();
2661 Block *parentBlock = op->getBlock();
2662
2663 for (Block *successor : parentBlock->getSuccessors())
2664 if (successor->isReachable(parentBlock))
2665 return WalkResult::interrupt();
2666
2667 for (Block &block : *parentRegion)
2668 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2669 !domInfo.dominates(parentBlock, &block))
2670 return WalkResult::interrupt();
2671 }
2672
2673 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2674 // into nested operations.
2675 for (Operation &sibling : op->getParentRegion()->getOps())
2676 if (&sibling != op && !siblingAllowedFn(&sibling))
2677 return WalkResult::interrupt();
2678
2679 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2680 // Otherwise, process the contents of this operation.
2681 capturedOp = op;
2682 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2684 });
2685
2686 return capturedOp;
2687}
2688
/// Returns the innermost OpenMP operation captured by this omp.target, as
/// computed by findCapturedOmpOp with single-mandatory-execution checks
/// enabled. The result may be null; callers such as getKernelExecFlags accept
/// a null captured op.
Operation *TargetOp::getInnermostCapturedOmpOp() {
  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();

  // Sibling filter: OpenMP siblings are only allowed when they are
  // terminators. Non-OpenMP siblings that implement MemoryEffectOpInterface
  // are rejected only if they report a write to an
  // AutomaticAllocationScopeResource. Non-OpenMP siblings without that
  // interface are accepted.
  return findCapturedOmpOp(
      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
        if (!sibling)
          return false;

        if (ompDialect == sibling->getDialect())
          return sibling->hasTrait<OpTrait::IsTerminator>();

        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
              effects;
          memOp.getEffects(effects);
          // Only a write to an automatic allocation scope resource
          // disqualifies the sibling; other effects are tolerated.
          return !llvm::any_of(
              effects, [&](MemoryEffects::EffectInstance &effect) {
                return isa<MemoryEffects::Write>(effect.getEffect()) &&
                       isa<SideEffects::AutomaticAllocationScopeResource>(
                           effect.getResource());
              });
        }
        // NOTE(review): ops with unknown memory effects (no
        // MemoryEffectOpInterface) are accepted here; confirm this is
        // intended.
        return true;
      });
}
2716
2717/// Check if we can promote SPMD kernel to No-Loop kernel.
2718static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2719 WsloopOp *wsLoopOp) {
2720 // num_teams clause can break no-loop teams/threads assumption.
2721 if (!teamsOp.getNumTeamsUpperVars().empty())
2722 return false;
2723
2724 // Reduction kernels are slower in no-loop mode.
2725 if (teamsOp.getNumReductionVars())
2726 return false;
2727 if (wsLoopOp->getNumReductionVars())
2728 return false;
2729
2730 // Check if the user allows the promotion of kernels to no-loop mode.
2731 OffloadModuleInterface offloadMod =
2732 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2733 if (!offloadMod)
2734 return false;
2735 auto ompFlags = offloadMod.getFlags();
2736 if (!ompFlags)
2737 return false;
2738 return ompFlags.getAssumeTeamsOversubscription() &&
2739 ompFlags.getAssumeThreadsOversubscription();
2740}
2741
/// Classifies the kernel execution mode of the omp.target region that holds
/// `capturedOp`, based on the loop wrappers around it and on how those are
/// nested below omp.parallel / omp.teams / omp.target.
///
/// \param capturedOp null, or the result of getInnermostCapturedOmpOp() on
///        its enclosing omp.target (checked by an assertion).
/// \param hostEvalTripCount optional out-parameter. It is reset to false and
///        set to true only for the target-teams-distribute-parallel-wsloop,
///        target-teams-distribute and target-teams-loop patterns when the
///        teams op is directly nested in the target.
/// \returns spmd or no_loop for
///          target-teams-distribute-parallel-wsloop[-simd], spmd for
///          target-teams-loop and target-parallel-wsloop[-simd], and generic
///          in every other case (including target-teams-distribute[-simd]).
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
                                            bool *hostEvalTripCount) {
  // TODO: Support detection of bare kernel mode.
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  if (hostEvalTripCount)
    *hostEvalTripCount = false;

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetExecMode::generic;

  // Get the innermost non-simd loop wrapper.
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  // The first gathered wrapper is treated as the innermost one; a leading
  // simd wrapper is skipped.
  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Number of remaining (non-simd) wrappers; only 1 or 2 can match a pattern.
  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetExecMode::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
    if (!wsloopOp)
      return TargetExecMode::generic;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetExecMode::generic;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
    if (!teamsOp)
      return TargetExecMode::generic;

    // Only a teams op directly inside the target yields an SPMD kernel;
    // otherwise control falls through to the final generic return.
    if (teamsOp->getParentOp() == targetOp.getOperation()) {
      TargetExecMode result = TargetExecMode::spmd;
      if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
        result = TargetExecMode::no_loop;
      if (hostEvalTripCount)
        *hostEvalTripCount = true;
      return result;
    }
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetExecMode::generic;

    if (hostEvalTripCount)
      *hostEvalTripCount = true;

    if (isa<LoopOp>(innermostWrapper))
      return TargetExecMode::spmd;

    // target-teams-distribute stays generic even though the trip count can
    // be evaluated on the host.
    return TargetExecMode::generic;
  }
  // Detect target-parallel-wsloop[-simd].
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetExecMode::spmd;
  }

  return TargetExecMode::generic;
}
2829
2830//===----------------------------------------------------------------------===//
2831// ParallelOp
2832//===----------------------------------------------------------------------===//
2833
2834void ParallelOp::build(OpBuilder &builder, OperationState &state,
2835 ArrayRef<NamedAttribute> attributes) {
2836 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2837 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2838 /*num_threads_vars=*/ValueRange(),
2839 /*private_vars=*/ValueRange(),
2840 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2841 /*proc_bind_kind=*/nullptr,
2842 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2843 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2844 state.addAttributes(attributes);
2845}
2846
2847void ParallelOp::build(OpBuilder &builder, OperationState &state,
2848 const ParallelOperands &clauses) {
2849 MLIRContext *ctx = builder.getContext();
2850 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2851 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2852 makeArrayAttr(ctx, clauses.privateSyms),
2853 clauses.privateNeedsBarrier, clauses.procBindKind,
2854 clauses.reductionMod, clauses.reductionVars,
2855 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2856 makeArrayAttr(ctx, clauses.reductionSyms));
2857}
2858
/// Verifies the private clause of `op`:
///  - the number of private variables equals the number of privatizer
///    symbols (both lists empty or absent is accepted),
///  - every symbol resolves to a PrivateClauseOp,
///  - when the privatizer declares an argument type, it equals the type of
///    the corresponding private variable.
/// \returns failure after emitting a diagnostic on `op` if any check fails.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privateSyms = op.getPrivateSymsAttr();

  // Nothing to verify when neither operands nor symbols are present.
  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
    return success();

  auto numPrivateVars = privateVars.size();
  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

  if (numPrivateVars != numPrivateSyms)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivateSyms;

  // Sizes match, so the two lists can be walked pairwise.
  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
    Type varType = std::get<0>(privateVarInfo).getType();
    SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
    PrivateClauseOp privatizerOp =

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privateSym << "'";

    Type privatizerType = privatizerOp.getArgType();

    // A privatizer without an argument type is not type-checked.
    if (privatizerType && (varType != privatizerType))
      return op.emitError()
             << "type mismatch between a "
             << (privatizerOp.getDataSharingType() ==
                         DataSharingClauseType::Private
                     ? "private"
                     : "firstprivate")
             << " variable and its privatizer op, var type: " << varType
             << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}
2901
2902LogicalResult ParallelOp::verify() {
2903 if (getAllocateVars().size() != getAllocatorVars().size())
2904 return emitError(
2905 "expected equal sizes for allocate and allocator variables");
2906
2907 if (failed(verifyPrivateVarList(*this)))
2908 return failure();
2909
2910 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2911 getReductionByref());
2912}
2913
2914LogicalResult ParallelOp::verifyRegions() {
2915 auto distChildOps = getOps<DistributeOp>();
2916 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2917 if (numDistChildOps > 1)
2918 return emitError()
2919 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2920
2921 if (numDistChildOps == 1) {
2922 if (!isComposite())
2923 return emitError()
2924 << "'omp.composite' attribute missing from composite operation";
2925
2926 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2927 Operation &distributeOp = **distChildOps.begin();
2928 for (Operation &childOp : getOps()) {
2929 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2930 continue;
2931
2932 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2933 return emitError() << "unexpected OpenMP operation inside of composite "
2934 "'omp.parallel': "
2935 << childOp.getName();
2936 }
2937 } else if (isComposite()) {
2938 return emitError()
2939 << "'omp.composite' attribute present in non-composite operation";
2940 }
2941 return success();
2942}
2943
2944//===----------------------------------------------------------------------===//
2945// TeamsOp
2946//===----------------------------------------------------------------------===//
2947
  // Walk every ancestor of `op`: finding any operation from the OpenMP
  // dialect makes the result false, and reaching the top makes it true.
  // NOTE(review): getDialect() can be null (e.g. for unregistered ops), and
  // llvm::isa asserts on a null pointer; isa_and_present may be intended.
  // Confirm.
  while ((op = op->getParentOp()))
    if (isa<OpenMPDialect>(op->getDialect()))
      return false;
  return true;
}
2954
2955void TeamsOp::build(OpBuilder &builder, OperationState &state,
2956 const TeamsOperands &clauses) {
2957 MLIRContext *ctx = builder.getContext();
2958 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2959 TeamsOp::build(
2960 builder, state, clauses.allocateVars, clauses.allocatorVars,
2961 clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
2962 clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
2963 clauses.numTeamsUpperVars, /*private_vars=*/{}, /*private_syms=*/nullptr,
2964 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2965 clauses.reductionVars,
2966 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2967 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2968}
2969
2970// Verify num_teams clause
2971static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2972 OperandRange numTeamsUpperVars) {
2973 // If lower is specified, upper must have exactly one value
2974 if (numTeamsLower) {
2975 if (numTeamsUpperVars.size() != 1)
2976 return op->emitError(
2977 "expected exactly one num_teams upper bound when lower bound is "
2978 "specified");
2979 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2980 return op->emitError(
2981 "expected num_teams upper bound and lower bound to be "
2982 "the same type");
2983 }
2984
2985 return success();
2986}
2987
/// Verifies omp.teams: its placement, the num_teams clause, allocate/allocator
/// pairing, the dyn_groupprivate clause attributes, and the private and
/// reduction clauses. Emits a diagnostic and fails at the first violation.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  if (!isa<TargetOp>(op->getParentOp()) &&
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
                                  this->getNumTeamsUpperVars())))
    return failure();

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

  // Check the dyn_groupprivate clause (access group, fallback and size).
          op, getDynGroupprivateAccessGroupAttr(),
          getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
    return failure();

  if (failed(verifyPrivateVarList(*this)))
    return failure();

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
3021
3022//===----------------------------------------------------------------------===//
3023// SectionOp
3024//===----------------------------------------------------------------------===//
3025
3026OperandRange SectionOp::getPrivateVars() {
3027 return getParentOp().getPrivateVars();
3028}
3029
3030OperandRange SectionOp::getReductionVars() {
3031 return getParentOp().getReductionVars();
3032}
3033
3034//===----------------------------------------------------------------------===//
3035// SectionsOp
3036//===----------------------------------------------------------------------===//
3037
3038void SectionsOp::build(OpBuilder &builder, OperationState &state,
3039 const SectionsOperands &clauses) {
3040 MLIRContext *ctx = builder.getContext();
3041 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3042 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3043 clauses.nowait, /*private_vars=*/{},
3044 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
3045 clauses.reductionMod, clauses.reductionVars,
3046 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3047 makeArrayAttr(ctx, clauses.reductionSyms));
3048}
3049
3050LogicalResult SectionsOp::verify() {
3051 if (getAllocateVars().size() != getAllocatorVars().size())
3052 return emitError(
3053 "expected equal sizes for allocate and allocator variables");
3054
3055 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3056 getReductionByref());
3057}
3058
3059LogicalResult SectionsOp::verifyRegions() {
3060 for (auto &inst : *getRegion().begin()) {
3061 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
3062 return emitOpError()
3063 << "expected omp.section op or terminator op inside region";
3064 }
3065 }
3066
3067 return success();
3068}
3069
3070//===----------------------------------------------------------------------===//
3071// ScopeOp
3072//===----------------------------------------------------------------------===//
3073
3074void ScopeOp::build(OpBuilder &builder, OperationState &state,
3075 const ScopeOperands &clauses) {
3076 MLIRContext *ctx = builder.getContext();
3077 ScopeOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3078 clauses.nowait, clauses.privateVars,
3079 makeArrayAttr(ctx, clauses.privateSyms),
3080 clauses.privateNeedsBarrier, clauses.reductionMod,
3081 clauses.reductionVars,
3082 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3083 makeArrayAttr(ctx, clauses.reductionSyms));
3084}
3085
3086LogicalResult ScopeOp::verify() {
3087 if (getAllocateVars().size() != getAllocatorVars().size())
3088 return emitError(
3089 "expected equal sizes for allocate and allocator variables");
3090
3091 if (failed(verifyPrivateVarList(*this)))
3092 return failure();
3093
3094 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3095 getReductionByref());
3096}
3097
3098//===----------------------------------------------------------------------===//
3099// SingleOp
3100//===----------------------------------------------------------------------===//
3101
3102void SingleOp::build(OpBuilder &builder, OperationState &state,
3103 const SingleOperands &clauses) {
3104 MLIRContext *ctx = builder.getContext();
3105 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3106 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3107 clauses.copyprivateVars,
3108 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
3109 /*private_vars=*/{}, /*private_syms=*/nullptr,
3110 /*private_needs_barrier=*/nullptr);
3111}
3112
3113LogicalResult SingleOp::verify() {
3114 // Check for allocate clause restrictions
3115 if (getAllocateVars().size() != getAllocatorVars().size())
3116 return emitError(
3117 "expected equal sizes for allocate and allocator variables");
3118
3119 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
3120 getCopyprivateSyms());
3121}
3122
3123//===----------------------------------------------------------------------===//
3124// WorkshareOp
3125//===----------------------------------------------------------------------===//
3126
3127void WorkshareOp::build(OpBuilder &builder, OperationState &state,
3128 const WorkshareOperands &clauses) {
3129 WorkshareOp::build(builder, state, clauses.nowait);
3130}
3131
3132//===----------------------------------------------------------------------===//
3133// WorkshareLoopWrapperOp
3134//===----------------------------------------------------------------------===//
3135
3136LogicalResult WorkshareLoopWrapperOp::verify() {
3137 if (!(*this)->getParentOfType<WorkshareOp>())
3138 return emitOpError() << "must be nested in an omp.workshare";
3139 return success();
3140}
3141
3142LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
3143 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3144 getNestedWrapper())
3145 return emitOpError() << "expected to be a standalone loop wrapper";
3146
3147 return success();
3148}
3149
3150//===----------------------------------------------------------------------===//
3151// LoopWrapperInterface
3152//===----------------------------------------------------------------------===//
3153
/// Verifies the structural invariants shared by every loop wrapper:
///  - the operation has the `NoTerminator` and `SingleBlock` traits,
///  - it has exactly one region,
///  - that region holds exactly one operation,
///  - that operation is another loop wrapper or an `omp.loop_nest`.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  // The single nested op must continue the wrapper chain or end it.
  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}
3176
3177//===----------------------------------------------------------------------===//
3178// LoopOp
3179//===----------------------------------------------------------------------===//
3180
3181void LoopOp::build(OpBuilder &builder, OperationState &state,
3182 const LoopOperands &clauses) {
3183 MLIRContext *ctx = builder.getContext();
3184
3185 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
3186 makeArrayAttr(ctx, clauses.privateSyms),
3187 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
3188 clauses.reductionMod, clauses.reductionVars,
3189 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3190 makeArrayAttr(ctx, clauses.reductionSyms));
3191}
3192
3193LogicalResult LoopOp::verify() {
3194 if (failed(verifyPrivateVarList(*this)))
3195 return failure();
3196
3197 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3198 getReductionByref());
3199}
3200
3201LogicalResult LoopOp::verifyRegions() {
3202 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3203 getNestedWrapper())
3204 return emitOpError() << "expected to be a standalone loop wrapper";
3205
3206 return success();
3207}
3208
3209//===----------------------------------------------------------------------===//
3210// WsloopOp
3211//===----------------------------------------------------------------------===//
3212
3213void WsloopOp::build(OpBuilder &builder, OperationState &state,
3214 ArrayRef<NamedAttribute> attributes) {
3215 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
3216 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
3217 /*linear_var_types*/ nullptr, /*linear_modifiers=*/nullptr,
3218 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
3219 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
3220 /*private_needs_barrier=*/false,
3221 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
3222 /*reduction_byref=*/nullptr,
3223 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
3224 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
3225 /*schedule_simd=*/false);
3226 state.addAttributes(attributes);
3227}
3228
3229void WsloopOp::build(OpBuilder &builder, OperationState &state,
3230 const WsloopOperands &clauses) {
3231 MLIRContext *ctx = builder.getContext();
3232 // TODO: Store clauses in op: allocateVars, allocatorVars
3233 WsloopOp::build(
3234 builder, state,
3235 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
3236 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3237 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3238 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3239 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3240 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3241 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3242 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3243}
3244
3245LogicalResult WsloopOp::verify() {
3246 if (failed(
3247 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3248 return failure();
3249 if (getLinearVars().size() &&
3250 getLinearVarTypes().value().size() != getLinearVars().size())
3251 return emitError() << "Ill-formed type attributes for linear variables";
3252
3253 if (failed(verifyPrivateVarList(*this)))
3254 return failure();
3255
3256 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3257 getReductionByref());
3258}
3259
3260LogicalResult WsloopOp::verifyRegions() {
3261 bool isCompositeChildLeaf =
3262 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3263
3264 if (LoopWrapperInterface nested = getNestedWrapper()) {
3265 if (!isComposite())
3266 return emitError()
3267 << "'omp.composite' attribute missing from composite wrapper";
3268
3269 // Check for the allowed leaf constructs that may appear in a composite
3270 // construct directly after DO/FOR.
3271 if (!isa<SimdOp>(nested))
3272 return emitError() << "only supported nested wrapper is 'omp.simd'";
3273
3274 } else if (isComposite() && !isCompositeChildLeaf) {
3275 return emitError()
3276 << "'omp.composite' attribute present in non-composite wrapper";
3277 } else if (!isComposite() && isCompositeChildLeaf) {
3278 return emitError()
3279 << "'omp.composite' attribute missing from composite wrapper";
3280 }
3281
3282 return success();
3283}
3284
3285//===----------------------------------------------------------------------===//
3286// Simd construct [2.9.3.1]
3287//===----------------------------------------------------------------------===//
3288
3289void SimdOp::build(OpBuilder &builder, OperationState &state,
3290 const SimdOperands &clauses) {
3291 MLIRContext *ctx = builder.getContext();
3292 SimdOp::build(builder, state, clauses.alignedVars,
3293 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
3294 clauses.linearVars, clauses.linearStepVars,
3295 clauses.linearVarTypes, clauses.linearModifiers,
3296 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3297 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3298 clauses.privateNeedsBarrier, clauses.reductionMod,
3299 clauses.reductionVars,
3300 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3301 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
3302 clauses.simdlen);
3303}
3304
/// Verifies omp.simd:
///  - simdlen, when given together with safelen, must not exceed it,
///  - the aligned, nontemporal and linear clauses are well formed,
///  - 'omp.composite' is set if and only if the parent is a loop wrapper,
///  - no privatizer symbol refers to a firstprivate privatizer,
///  - the private list is consistent and linear_var_types matches the
///    linear variables.
LogicalResult SimdOp::verify() {
  if (getSimdlen().has_value() && getSafelen().has_value() &&
      getSimdlen().value() > getSafelen().value())
    return emitOpError()
           << "simdlen clause and safelen clause are both present, but the "
              "simdlen value is not less than or equal to safelen value";

  if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
    return failure();

  if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
    return failure();

  if (failed(
          verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
    return failure();

  // Being wrapped by another loop wrapper is what makes simd a composite leaf.
  bool isCompositeChildLeaf =
      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());

  if (!isComposite() && isCompositeChildLeaf)
    return emitError()
           << "'omp.composite' attribute missing from composite wrapper";

  if (isComposite() && !isCompositeChildLeaf)
    return emitError()
           << "'omp.composite' attribute present in non-composite wrapper";

  // Firstprivate is not allowed for SIMD in the standard. Check that none of
  // the private decls are for firstprivate.
  std::optional<ArrayAttr> privateSyms = getPrivateSyms();
  if (privateSyms) {
    for (const Attribute &sym : *privateSyms) {
      auto symRef = cast<SymbolRefAttr>(sym);
      omp::PrivateClauseOp privatizer =
              getOperation(), symRef);
      if (!privatizer)
        return emitError() << "Cannot find privatizer '" << symRef << "'";
      if (privatizer.getDataSharingType() ==
          DataSharingClauseType::FirstPrivate)
        return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
    }
  }

  if (failed(verifyPrivateVarList(*this)))
    return failure();

  // NOTE(review): getLinearVarTypes().value() is called without checking that
  // the optional attribute is present. If linear vars can exist without it,
  // this aborts instead of emitting the diagnostic below. Confirm.
  if (getLinearVars().size() &&
      getLinearVarTypes().value().size() != getLinearVars().size())
    return emitError() << "Ill-formed type attributes for linear variables";
  return success();
}
3358
3359LogicalResult SimdOp::verifyRegions() {
3360 if (getNestedWrapper())
3361 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3362
3363 return success();
3364}
3365
3366//===----------------------------------------------------------------------===//
3367// Distribute construct [2.9.4.1]
3368//===----------------------------------------------------------------------===//
3369
3370void DistributeOp::build(OpBuilder &builder, OperationState &state,
3371 const DistributeOperands &clauses) {
3372 DistributeOp::build(builder, state, clauses.allocateVars,
3373 clauses.allocatorVars, clauses.distScheduleStatic,
3374 clauses.distScheduleChunkSize, clauses.order,
3375 clauses.orderMod, clauses.privateVars,
3376 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3377 clauses.privateNeedsBarrier);
3378}
3379
3380LogicalResult DistributeOp::verify() {
3381 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3382 return emitOpError() << "chunk size set without "
3383 "dist_schedule_static being present";
3384
3385 if (getAllocateVars().size() != getAllocatorVars().size())
3386 return emitError(
3387 "expected equal sizes for allocate and allocator variables");
3388
3389 if (failed(verifyPrivateVarList(*this)))
3390 return failure();
3391
3392 return success();
3393}
3394
3395LogicalResult DistributeOp::verifyRegions() {
3396 if (LoopWrapperInterface nested = getNestedWrapper()) {
3397 if (!isComposite())
3398 return emitError()
3399 << "'omp.composite' attribute missing from composite wrapper";
3400 // Check for the allowed leaf constructs that may appear in a composite
3401 // construct directly after DISTRIBUTE.
3402 if (isa<WsloopOp>(nested)) {
3403 Operation *parentOp = (*this)->getParentOp();
3404 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3405 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3406 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3407 "when a composite 'omp.parallel' is the direct "
3408 "parent";
3409 }
3410 } else if (!isa<SimdOp>(nested))
3411 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3412 "'omp.wsloop'";
3413 } else if (isComposite()) {
3414 return emitError()
3415 << "'omp.composite' attribute present in non-composite wrapper";
3416 }
3417
3418 return success();
3419}
3420
3421//===----------------------------------------------------------------------===//
3422// DeclareMapperOp / DeclareMapperInfoOp
3423//===----------------------------------------------------------------------===//
3424
3425LogicalResult DeclareMapperInfoOp::verify() {
3426 return verifyMapClause(*this, getMapVars());
3427}
3428
3429LogicalResult DeclareMapperOp::verifyRegions() {
3430 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3431 getRegion().getBlocks().front().getTerminator()))
3432 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3433
3434 return success();
3435}
3436
3437//===----------------------------------------------------------------------===//
3438// DeclareReductionOp
3439//===----------------------------------------------------------------------===//
3440
3441LogicalResult DeclareReductionOp::verifyRegions() {
3442 if (!getAllocRegion().empty()) {
3443 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3444 if (yieldOp.getResults().size() != 1 ||
3445 yieldOp.getResults().getTypes()[0] != getType())
3446 return emitOpError() << "expects alloc region to yield a value "
3447 "of the reduction type";
3448 }
3449 }
3450
3451 if (getInitializerRegion().empty())
3452 return emitOpError() << "expects non-empty initializer region";
3453 Block &initializerEntryBlock = getInitializerRegion().front();
3454
3455 if (initializerEntryBlock.getNumArguments() == 1) {
3456 if (!getAllocRegion().empty())
3457 return emitOpError() << "expects two arguments to the initializer region "
3458 "when an allocation region is used";
3459 } else if (initializerEntryBlock.getNumArguments() == 2) {
3460 if (getAllocRegion().empty())
3461 return emitOpError() << "expects one argument to the initializer region "
3462 "when no allocation region is used";
3463 } else {
3464 return emitOpError()
3465 << "expects one or two arguments to the initializer region";
3466 }
3467
3468 for (mlir::Value arg : initializerEntryBlock.getArguments())
3469 if (arg.getType() != getType())
3470 return emitOpError() << "expects initializer region argument to match "
3471 "the reduction type";
3472
3473 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3474 if (yieldOp.getResults().size() != 1 ||
3475 yieldOp.getResults().getTypes()[0] != getType())
3476 return emitOpError() << "expects initializer region to yield a value "
3477 "of the reduction type";
3478 }
3479
3480 if (getReductionRegion().empty())
3481 return emitOpError() << "expects non-empty reduction region";
3482 Block &reductionEntryBlock = getReductionRegion().front();
3483 if (reductionEntryBlock.getNumArguments() != 2 ||
3484 reductionEntryBlock.getArgumentTypes()[0] !=
3485 reductionEntryBlock.getArgumentTypes()[1] ||
3486 reductionEntryBlock.getArgumentTypes()[0] != getType())
3487 return emitOpError() << "expects reduction region with two arguments of "
3488 "the reduction type";
3489 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3490 if (yieldOp.getResults().size() != 1 ||
3491 yieldOp.getResults().getTypes()[0] != getType())
3492 return emitOpError() << "expects reduction region to yield a value "
3493 "of the reduction type";
3494 }
3495
3496 if (!getAtomicReductionRegion().empty()) {
3497 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3498 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3499 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3500 atomicReductionEntryBlock.getArgumentTypes()[1])
3501 return emitOpError() << "expects atomic reduction region with two "
3502 "arguments of the same type";
3503 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3504 atomicReductionEntryBlock.getArgumentTypes()[0]);
3505 if (!ptrType ||
3506 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3507 return emitOpError() << "expects atomic reduction region arguments to "
3508 "be accumulators containing the reduction type";
3509 }
3510
3511 if (getCleanupRegion().empty())
3512 return success();
3513 Block &cleanupEntryBlock = getCleanupRegion().front();
3514 if (cleanupEntryBlock.getNumArguments() != 1 ||
3515 cleanupEntryBlock.getArgument(0).getType() != getType())
3516 return emitOpError() << "expects cleanup region with one argument "
3517 "of the reduction type";
3518
3519 return success();
3520}
3521
3522//===----------------------------------------------------------------------===//
3523// TaskOp
3524//===----------------------------------------------------------------------===//
3525
3526void TaskOp::build(OpBuilder &builder, OperationState &state,
3527 const TaskOperands &clauses) {
3528 MLIRContext *ctx = builder.getContext();
3529 TaskOp::build(
3530 builder, state, clauses.iterated, clauses.affinityVars,
3531 clauses.allocateVars, clauses.allocatorVars,
3532 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3533 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3534 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3535 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3536 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3537 clauses.priority, /*private_vars=*/clauses.privateVars,
3538 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3539 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3540}
3541
3542LogicalResult TaskOp::verify() {
3543 LogicalResult verifyDependVars =
3544 verifyDependVarList(*this, getDependKinds(), getDependVars(),
3545 getDependIteratedKinds(), getDependIterated());
3546 if (failed(verifyDependVars))
3547 return verifyDependVars;
3548
3549 if (failed(verifyPrivateVarList(*this)))
3550 return failure();
3551
3552 return verifyReductionVarList(*this, getInReductionSyms(),
3553 getInReductionVars(), getInReductionByref());
3554}
3555
3556//===----------------------------------------------------------------------===//
3557// TaskgroupOp
3558//===----------------------------------------------------------------------===//
3559
3560void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3561 const TaskgroupOperands &clauses) {
3562 MLIRContext *ctx = builder.getContext();
3563 TaskgroupOp::build(builder, state, clauses.allocateVars,
3564 clauses.allocatorVars, clauses.taskReductionVars,
3565 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3566 makeArrayAttr(ctx, clauses.taskReductionSyms));
3567}
3568
3569LogicalResult TaskgroupOp::verify() {
3570 return verifyReductionVarList(*this, getTaskReductionSyms(),
3571 getTaskReductionVars(),
3572 getTaskReductionByref());
3573}
3574
3575//===----------------------------------------------------------------------===//
3576// TaskloopContextOp
3577//===----------------------------------------------------------------------===//
3578
3579void TaskloopContextOp::build(OpBuilder &builder, OperationState &state,
3580 const TaskloopContextOperands &clauses) {
3581 MLIRContext *ctx = builder.getContext();
3582 TaskloopContextOp::build(
3583 builder, state, clauses.allocateVars, clauses.allocatorVars,
3584 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3585 clauses.inReductionVars,
3586 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3587 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3588 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3589 /*private_vars=*/clauses.privateVars,
3590 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3591 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3592 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3593 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3594}
3595
3596TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3597 return cast<TaskloopWrapperOp>(
3598 *llvm::find_if(getRegion().front(), [](mlir::Operation &op) {
3599 return isa<TaskloopWrapperOp>(op);
3600 }));
3601}
3602
3603LogicalResult TaskloopContextOp::verify() {
3604 if (getAllocateVars().size() != getAllocatorVars().size())
3605 return emitError(
3606 "expected equal sizes for allocate and allocator variables");
3607
3608 if (failed(verifyPrivateVarList(*this)))
3609 return failure();
3610
3611 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3612 getReductionVars(), getReductionByref())) ||
3613 failed(verifyReductionVarList(*this, getInReductionSyms(),
3614 getInReductionVars(),
3615 getInReductionByref())))
3616 return failure();
3617
3618 if (!getReductionVars().empty() && getNogroup())
3619 return emitError("if a reduction clause is present on the taskloop "
3620 "directive, the nogroup clause must not be specified");
3621 for (auto var : getReductionVars()) {
3622 if (llvm::is_contained(getInReductionVars(), var))
3623 return emitError("the same list item cannot appear in both a reduction "
3624 "and an in_reduction clause");
3625 }
3626
3627 if (getGrainsize() && getNumTasks()) {
3628 return emitError(
3629 "the grainsize clause and num_tasks clause are mutually exclusive and "
3630 "may not appear on the same taskloop directive");
3631 }
3632
3633 return success();
3634}
3635
3636LogicalResult TaskloopContextOp::verifyRegions() {
3637 Region &region = getRegion();
3638 if (region.empty())
3639 return emitOpError() << "expected non-empty region";
3640
3641 auto count = llvm::count_if(region.front(), [](mlir::Operation &op) {
3642 return isa<TaskloopWrapperOp>(op);
3643 });
3644 if (count != 1)
3645 return emitOpError()
3646 << "expected exactly 1 TaskloopWrapperOp directly nested in "
3647 "the region, but "
3648 << count << " were found";
3649 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3650
3651 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3652 // This will fail the verifier for TaskloopWrapperOp and print an error
3653 // message there.
3654 if (!loopNestOp)
3655 return failure();
3656
3657 std::function<bool(Value)> isValidBoundValue = [&](Value value) -> bool {
3658 Region *valueRegion = value.getParentRegion();
3659 // A loop bound value defined outside of the taskloop context region is
3660 // valid. A region is considered an ancestor of itself.
3661 if (!region.isAncestor(valueRegion))
3662 return true;
3663
3664 Operation *defOp = value.getDefiningOp();
3665 if (!defOp || defOp->getNumRegions() != 0 || !isPure(defOp))
3666 return false;
3667
3668 return llvm::all_of(defOp->getOperands(), isValidBoundValue);
3669 };
3670 auto hasUnsupportedTaskloopLocalBound = [&](OperandRange range) -> bool {
3671 return llvm::any_of(range,
3672 [&](Value value) { return !isValidBoundValue(value); });
3673 };
3674
3675 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3676 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3677 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3678 return emitOpError()
3679 << "expects loop bounds and steps to be defined outside of the "
3680 "taskloop.context region or by pure, regionless operations "
3681 "that do not depend on block arguments";
3682 }
3683
3684 return success();
3685}
3686
3687//===----------------------------------------------------------------------===//
3688// TaskloopWrapperOp
3689//===----------------------------------------------------------------------===//
3690
3691void TaskloopWrapperOp::build(OpBuilder &builder, OperationState &state,
3692 const TaskloopWrapperOperands &clauses) {
3693 TaskloopWrapperOp::build(builder, state);
3694}
3695
3696TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3697 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3698}
3699
3700LogicalResult TaskloopWrapperOp::verify() {
3701 TaskloopContextOp context = getTaskloopContext();
3702 if (!context)
3703 return emitOpError() << "expected to be nested in a taskloop context op";
3704 return success();
3705}
3706
3707LogicalResult TaskloopWrapperOp::verifyRegions() {
3708 if (LoopWrapperInterface nested = getNestedWrapper()) {
3709 if (!isComposite())
3710 return emitError()
3711 << "'omp.composite' attribute missing from composite wrapper";
3712
3713 // Check for the allowed leaf constructs that may appear in a composite
3714 // construct directly after TASKLOOP.
3715 if (!isa<SimdOp>(nested))
3716 return emitError() << "only supported nested wrapper is 'omp.simd'";
3717 } else if (isComposite()) {
3718 return emitError()
3719 << "'omp.composite' attribute present in non-composite wrapper";
3720 }
3721
3722 return success();
3723}
3724
3725//===----------------------------------------------------------------------===//
3726// LoopNestOp
3727//===----------------------------------------------------------------------===//
3728
/// Parses the custom form of `omp.loop_nest`:
///   (ivs) : type = (lbs) to (ubs) [inclusive] step (steps)
///   [collapse(N)] [tiles(t0, t1, ...)] region [attr-dict]
/// The bounds and steps lists are parsed with `ivs.size()` as the required
/// element count. All of them are resolved to the single type given after `:`.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse an opening `(` followed by induction variables followed by `)`
  Type loopVarType;
  parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Every induction variable gets the single type parsed after `:`.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  auto *ctx = parser.getBuilder().getContext();
  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive", UnitAttr::get(ctx));

  // Parse step values.
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse collapse
  // `collapse_num_loops` is only recorded for values greater than one;
  // an absent clause, `collapse(0)` or `collapse(1)` adds no attribute.
  int64_t value = 0;
  if (!parser.parseOptionalKeyword("collapse") &&
      (parser.parseLParen() || parser.parseInteger(value) ||
       parser.parseRParen()))
    return failure();
  if (value > 1)
    result.addAttribute(
        "collapse_num_loops",
        IntegerAttr::get(parser.getBuilder().getI64Type(), value));

  // Parse tiles
  // Each comma-separated element of `tiles(...)` is an integer tile size.
  auto parseTiles = [&]() -> ParseResult {
    int64_t tile;
    if (parser.parseInteger(tile))
      return failure();
    tiles.push_back(tile);
    return success();
  };

  if (!parser.parseOptionalKeyword("tiles") &&
      (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
       parser.parseRParen()))
    return failure();

  // `tile_sizes` is only attached when at least one tile size was parsed.
  if (tiles.size() > 0)
    result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));

  // Parse the body.
  // The induction variables become the region's entry block arguments.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands.
  // Operand order is lower bounds, then upper bounds, then steps.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
3800
3801void LoopNestOp::print(OpAsmPrinter &p) {
3802 Region &region = getRegion();
3803 auto args = region.getArguments();
3804 p << " (" << args << ") : " << args[0].getType() << " = ("
3805 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3806 if (getLoopInclusive())
3807 p << "inclusive ";
3808 p << "step (" << getLoopSteps() << ") ";
3809 if (int64_t numCollapse = getCollapseNumLoops())
3810 if (numCollapse > 1)
3811 p << "collapse(" << numCollapse << ") ";
3812
3813 if (const auto tiles = getTileSizes())
3814 p << "tiles(" << tiles.value() << ") ";
3815
3816 p.printRegion(region, /*printEntryBlockArgs=*/false);
3817}
3818
3819void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3820 const LoopNestOperands &clauses) {
3821 MLIRContext *ctx = builder.getContext();
3822 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3823 clauses.loopLowerBounds, clauses.loopUpperBounds,
3824 clauses.loopSteps, clauses.loopInclusive,
3825 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3826}
3827
3828LogicalResult LoopNestOp::verify() {
3829 if (getLoopLowerBounds().empty())
3830 return emitOpError() << "must represent at least one loop";
3831
3832 if (getLoopLowerBounds().size() != getIVs().size())
3833 return emitOpError() << "number of range arguments and IVs do not match";
3834
3835 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3836 if (lb.getType() != iv.getType())
3837 return emitOpError()
3838 << "range argument type does not match corresponding IV type";
3839 }
3840
3841 uint64_t numIVs = getIVs().size();
3842
3843 if (const auto &numCollapse = getCollapseNumLoops())
3844 if (numCollapse > numIVs)
3845 return emitOpError()
3846 << "collapse value is larger than the number of loops";
3847
3848 if (const auto &tiles = getTileSizes())
3849 if (tiles.value().size() > numIVs)
3850 return emitOpError() << "too few canonical loops for tile dimensions";
3851
3852 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3853 return emitOpError() << "expects parent op to be a loop wrapper";
3854
3855 return success();
3856}
3857
/// Appends the chain of enclosing loop wrappers to `wrappers`, innermost
/// first. Walking stops at the first ancestor that is not a
/// LoopWrapperInterface (or when there is no further parent).
void LoopNestOp::gatherWrappers(
  Operation *parent = (*this)->getParentOp();
  while (auto wrapper =
             llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
    wrappers.push_back(wrapper);
    parent = parent->getParentOp();
  }
}
3867
3868//===----------------------------------------------------------------------===//
3869// OpenMP canonical loop handling
3870//===----------------------------------------------------------------------===//
3871
3872std::tuple<NewCliOp, OpOperand *, OpOperand *>
3873mlir::omp ::decodeCli(Value cli) {
3874
3875 // Defining a CLI for a generated loop is optional; if there is none then
3876 // there is no followup-tranformation
3877 if (!cli)
3878 return {{}, nullptr, nullptr};
3879
3880 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3881 "Unexpected type of cli");
3882
3883 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3884 OpOperand *gen = nullptr;
3885 OpOperand *cons = nullptr;
3886 for (OpOperand &use : cli.getUses()) {
3887 auto op = cast<LoopTransformationInterface>(use.getOwner());
3888
3889 unsigned opnum = use.getOperandNumber();
3890 if (op.isGeneratee(opnum)) {
3891 assert(!gen && "Each CLI may have at most one def");
3892 gen = &use;
3893 } else if (op.isApplyee(opnum)) {
3894 assert(!cons && "Each CLI may have at most one consumer");
3895 cons = &use;
3896 } else {
3897 llvm_unreachable("Unexpected operand for a CLI");
3898 }
3899 }
3900
3901 return {create, gen, cons};
3902}
3903
3904void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3905 ::mlir::OperationState &odsState) {
3906 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3907}
3908
/// Suggests an SSA name for the CLI result, derived from the operation that
/// generates the loop it refers to (see the naming scheme below).
void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  Value result = getResult();
  auto [newCli, gen, cons] = decodeCli(result);

  // Structured binding `gen` cannot be captured in lambdas before C++20
  OpOperand *generator = gen;

  // Derive the CLI variable name from its generator:
  //  * "canonloop" for omp.canonical_loop
  //  * custom name for loop transformation generatees
  //  * "cli" as fallback if no generator
  //  * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
  //  at that level
  //  * "_s<idx>" suffix for operations with multiple regions, where <idx> is
  //  the index of that region
  std::string cliName{"cli"};
  if (gen) {
    cliName =
            .Case([&](CanonicalLoopOp op) {
              return generateLoopNestingName("canonloop", op);
            })
            .Case([&](UnrollHeuristicOp op) -> std::string {
              llvm_unreachable("heuristic unrolling does not generate a loop");
            })
            .Case([&](FuseOp op) -> std::string {
              unsigned opnum = generator->getOperandNumber();
              // The position of the first loop to be fused is the same position
              // as the resulting fused loop
              // NOTE(review): `opnum` is an absolute operand index (0-based),
              // while FuseOp::verify's `first + count - 1` check suggests
              // `first` is 1-based. The TileOp case below subtracts the
              // generatee segment start. Confirm this comparison is not off by
              // one.
              if (op.getFirst().has_value() && opnum != op.getFirst().value())
                return "canonloop_fuse";
              else
                return "fused";
            })
            .Case([&](TileOp op) -> std::string {
              // The first half of the generatees are the grid loops, the
              // second half the intra-tile loops.
              auto [generateesFirst, generateesCount] =
                  op.getGenerateesODSOperandIndexAndLength();
              unsigned firstGrid = generateesFirst;
              unsigned firstIntratile = generateesFirst + generateesCount / 2;
              unsigned end = generateesFirst + generateesCount;
              unsigned opnum = generator->getOperandNumber();
              // In the OpenMP apply and looprange clauses, indices are 1-based
              if (firstGrid <= opnum && opnum < firstIntratile) {
                unsigned gridnum = opnum - firstGrid + 1;
                return ("grid" + Twine(gridnum)).str();
              }
              if (firstIntratile <= opnum && opnum < end) {
                unsigned intratilenum = opnum - firstIntratile + 1;
                return ("intratile" + Twine(intratilenum)).str();
              }
              llvm_unreachable("Unexpected generatee argument");
            })
            .DefaultUnreachable("TODO: Custom name for this operation");
  }

  setNameFn(result, cliName);
}
3966
3967LogicalResult NewCliOp::verify() {
3968 Value cli = getResult();
3969
3970 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3971 "Unexpected type of cli");
3972
3973 // Check that the CLI is used in at most generator and one consumer
3974 OpOperand *gen = nullptr;
3975 OpOperand *cons = nullptr;
3976 for (mlir::OpOperand &use : cli.getUses()) {
3977 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3978
3979 unsigned opnum = use.getOperandNumber();
3980 if (op.isGeneratee(opnum)) {
3981 if (gen) {
3982 InFlightDiagnostic error =
3983 emitOpError("CLI must have at most one generator");
3984 error.attachNote(gen->getOwner()->getLoc())
3985 .append("first generator here:");
3986 error.attachNote(use.getOwner()->getLoc())
3987 .append("second generator here:");
3988 return error;
3989 }
3990
3991 gen = &use;
3992 } else if (op.isApplyee(opnum)) {
3993 if (cons) {
3994 InFlightDiagnostic error =
3995 emitOpError("CLI must have at most one consumer");
3996 error.attachNote(cons->getOwner()->getLoc())
3997 .append("first consumer here:")
3998 .appendOp(*cons->getOwner(),
3999 OpPrintingFlags().printGenericOpForm());
4000 error.attachNote(use.getOwner()->getLoc())
4001 .append("second consumer here:")
4002 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
4003 return error;
4004 }
4005
4006 cons = &use;
4007 } else {
4008 llvm_unreachable("Unexpected operand for a CLI");
4009 }
4010 }
4011
4012 // If the CLI is source of a transformation, it must have a generator
4013 if (cons && !gen) {
4014 InFlightDiagnostic error = emitOpError("CLI has no generator");
4015 error.attachNote(cons->getOwner()->getLoc())
4016 .append("see consumer here: ")
4017 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
4018 return error;
4019 }
4020
4021 return success();
4022}
4023
4024void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4025 Value tripCount) {
4026 odsState.addOperands(tripCount);
4027 odsState.addOperands(Value());
4028 (void)odsState.addRegion();
4029}
4030
4031void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4032 Value tripCount, ::mlir::Value cli) {
4033 odsState.addOperands(tripCount);
4034 odsState.addOperands(cli);
4035 (void)odsState.addRegion();
4036}
4037
4038void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
4039 setNameFn(&getRegion().front(), "body_entry");
4040}
4041
4042void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
4043 OpAsmSetValueNameFn setNameFn) {
4044 std::string ivName = generateLoopNestingName("iv", *this);
4045 setNameFn(region.getArgument(0), ivName);
4046}
4047
4048void CanonicalLoopOp::print(OpAsmPrinter &p) {
4049 if (getCli())
4050 p << '(' << getCli() << ')';
4051 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
4052 << " in range(" << getTripCount() << ") ";
4053
4054 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4055 /*printBlockTerminators=*/true);
4056
4057 p.printOptionalAttrDict((*this)->getAttrs());
4058}
4059
/// Parses `[(%cli)] %iv : type in range(%tripcount) region [attr-dict]`.
/// The optional CLI is parsed first but appended to the operand list last,
/// after the trip count.
mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
  CanonicalLoopInfoType cliType =
      CanonicalLoopInfoType::get(parser.getContext());

  // Parse (optional) omp.cli identifier
  SmallVector<mlir::Value, 1> cliOperand;
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperand(cli) ||
        parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
      return failure();
  }

  // We derive the type of tripCount from inductionVariable. MLIR requires the
  // type of tripCount to be known when calling resolveOperand, so we have to
  // parse the type before processing the inductionVariable.
  OpAsmParser::Argument inductionVariable;
  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
      parser.parseKeyword("in") || parser.parseKeyword("range") ||
      parser.parseLParen() || parser.parseOperand(tripcount) ||
      parser.parseRParen() ||
      parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
    return failure();

  // Parse the loop body.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, {inductionVariable}))
    return failure();

  // We parsed the cli operand first, but because it is optional, it must be
  // last in the operand list.
  result.operands.append(cliOperand);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
4101
4102LogicalResult CanonicalLoopOp::verify() {
4103 // The region's entry must accept the induction variable
4104 // It can also be empty if just created
4105 if (!getRegion().empty()) {
4106 Region &region = getRegion();
4107 if (region.getNumArguments() != 1)
4108 return emitOpError(
4109 "Canonical loop region must have exactly one argument");
4110
4111 if (getInductionVar().getType() != getTripCount().getType())
4112 return emitOpError(
4113 "Region argument must be the same type as the trip count");
4114 }
4115
4116 return success();
4117}
4118
4119Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
4120
4121std::pair<unsigned, unsigned>
4122CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
4123 // No applyees
4124 return {0, 0};
4125}
4126
4127std::pair<unsigned, unsigned>
4128CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
4129 return getODSOperandIndexAndLength(odsIndex_cli);
4130}
4131
4132//===----------------------------------------------------------------------===//
4133// UnrollHeuristicOp
4134//===----------------------------------------------------------------------===//
4135
4136void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
4137 ::mlir::OperationState &odsState,
4138 ::mlir::Value cli) {
4139 odsState.addOperands(cli);
4140}
4141
4142void UnrollHeuristicOp::print(OpAsmPrinter &p) {
4143 p << '(' << getApplyee() << ')';
4144
4145 p.printOptionalAttrDict((*this)->getAttrs());
4146}
4147
/// Parses `(%applyee) [-> ()] [attr-dict]`. The applyee is resolved as a
/// CanonicalLoopInfoType operand. The optional `-> ()` spells out that no
/// output loop is produced and adds nothing to the operation state.
mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
  auto cliType = CanonicalLoopInfoType::get(parser.getContext());

  if (parser.parseLParen())
    return failure();

  if (parser.parseOperand(applyee) ||
      parser.resolveOperand(applyee, cliType, result.operands))
    return failure();

  if (parser.parseRParen())
    return failure();

  // Optional output loop (full unrolling has none)
  if (!parser.parseOptionalArrow()) {
    if (parser.parseLParen() || parser.parseRParen())
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
4175
4176std::pair<unsigned, unsigned>
4177UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
4178 return getODSOperandIndexAndLength(odsIndex_applyee);
4179}
4180
4181std::pair<unsigned, unsigned>
4182UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
4183 return {0, 0};
4184}
4185
4186//===----------------------------------------------------------------------===//
4187// TileOp
4188//===----------------------------------------------------------------------===//
4189
4190static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
4191 OperandRange generatees,
4192 OperandRange applyees) {
4193 if (!generatees.empty())
4194 p << '(' << llvm::interleaved(generatees) << ')';
4195
4196 if (!applyees.empty())
4197 p << " <- (" << llvm::interleaved(applyees) << ')';
4198}
4199
/// Parses the CLI lists of a loop transformation in one of two forms:
///   `(generatees) <- (applyees)`  or  `<- (applyees)`.
/// If a leading `<` is present, the generatee list was omitted.
static ParseResult parseLoopTransformClis(
    OpAsmParser &parser,
  // parseOptionalLess() fails when the next token is not `<`, which means a
  // parenthesized generatee list comes first.
  if (parser.parseOptionalLess()) {
    // Syntax 1: generatees present

    if (parser.parseOperandList(generateesOperands,
      return failure();

    if (parser.parseLess())
      return failure();
  } else {
    // Syntax 2: generatees omitted
  }

  // Parse `<-` (`<` has already been parsed)
  if (parser.parseMinus())
    return failure();

  if (parser.parseOperandList(applyeesOperands,
    return failure();

  return success();
}
4227
/// Check properties of the loop nest consisting of the transformation's
/// applyees:
///  1. They are nested inside each other
///  2. They are perfectly nested
///     (no code with side-effects in-between the loops)
///  3. They are rectangular
///     (loop bounds are invariant with respect to the outer loops)
///
/// Checks 1-3 are only performed when every applyee is generated by an
/// omp.canonical_loop. Otherwise only the existence of a generator is checked.
///
/// TODO: Generalize for LoopTransformationInterface.
static LogicalResult checkApplyeesNesting(TileOp op) {
  // Collect the loops from the nest
  bool isOnlyCanonLoops = true;
  for (Value applyee : op.getApplyees()) {
    auto [create, gen, cons] = decodeCli(applyee);

    if (!gen)
      return op.emitOpError() << "applyee CLI has no generator";

    // A null entry is recorded for applyees not generated by a canonical loop.
    auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
    canonLoops.push_back(loop);
    if (!loop)
      isOnlyCanonLoops = false;
  }

  // FIXME: We currently can only verify non-rectangularity and perfect nest of
  // omp.canonical_loop.
  if (!isOnlyCanonLoops)
    return success();

  // Induction variables of all loops outside the one currently checked.
  DenseSet<Value> parentIVs;
  for (auto i : llvm::seq<int>(1, canonLoops.size())) {
    auto parentLoop = canonLoops[i - 1];
    auto loop = canonLoops[i];

    if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
      return op.emitOpError()
             << "tiled loop nest must be nested within each other";

    parentIVs.insert(parentLoop.getInductionVar());

    // Canonical loop must be perfectly nested, i.e. the body of the parent must
    // only contain the omp.canonical_loop of the nested loops, and
    // omp.terminator
    bool isPerfectlyNested = [&]() {
      auto &parentBody = parentLoop.getRegion();
      if (!parentBody.hasOneBlock())
        return false;
      auto &parentBlock = parentBody.getBlocks().front();

      auto nestedLoopIt = parentBlock.begin();
      if (nestedLoopIt == parentBlock.end() ||
          (&*nestedLoopIt != loop.getOperation()))
        return false;

      auto termIt = std::next(nestedLoopIt);
      if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
        return false;

      if (std::next(termIt) != parentBlock.end())
        return false;

      return true;
    }();
    if (!isPerfectlyNested)
      return op.emitOpError() << "tiled loop nest must be perfectly nested";

    // NOTE(review): this only detects a trip count that *is* an outer IV. A
    // trip count computed from an outer IV is not caught. Confirm that this
    // is acceptable for now.
    if (parentIVs.contains(loop.getTripCount()))
      return op.emitOpError() << "tiled loop nest must be rectangular";
  }

  // TODO: The tile sizes must be computed before the loop, but checking this
  // requires dominance analysis. For instance:
  //
  //      %canonloop = omp.new_cli
  //      omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
  //        // write to %x
  //        omp.terminator
  //      }
  //      %ts = llvm.load %x
  //      omp.tile <- (%canonloop) sizes(%ts : i32)

  return success();
}
4312
4313LogicalResult TileOp::verify() {
4314 if (getApplyees().empty())
4315 return emitOpError() << "must apply to at least one loop";
4316
4317 if (getSizes().size() != getApplyees().size())
4318 return emitOpError() << "there must be one tile size for each applyee";
4319
4320 if (!getGeneratees().empty() &&
4321 2 * getSizes().size() != getGeneratees().size())
4322 return emitOpError()
4323 << "expecting two times the number of generatees than applyees";
4324
4325 return checkApplyeesNesting(*this);
4326}
4327
4328std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4329 return getODSOperandIndexAndLength(odsIndex_applyees);
4330}
4331
4332std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4333 return getODSOperandIndexAndLength(odsIndex_generatees);
4334}
4335
4336//===----------------------------------------------------------------------===//
4337// FuseOp
4338//===----------------------------------------------------------------------===//
4339
4340static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
4341 OperandRange generatees,
4342 OperandRange applyees) {
4343 if (!generatees.empty())
4344 p << '(' << llvm::interleaved(generatees) << ')';
4345
4346 if (!applyees.empty())
4347 p << " <- (" << llvm::interleaved(applyees) << ')';
4348}
4349
4350LogicalResult FuseOp::verify() {
4351 if (getApplyees().size() < 2)
4352 return emitOpError() << "must apply to at least two loops";
4353
4354 if (getFirst().has_value() && getCount().has_value()) {
4355 int64_t first = getFirst().value();
4356 int64_t count = getCount().value();
4357 if ((unsigned)(first + count - 1) > getApplyees().size())
4358 return emitOpError() << "the numbers of applyees must be at least first "
4359 "minus one plus count attributes";
4360 if (!getGeneratees().empty() &&
4361 getGeneratees().size() != getApplyees().size() + 1 - count)
4362 return emitOpError() << "the number of generatees must be the number of "
4363 "aplyees plus one minus count";
4364
4365 } else {
4366 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4367 return emitOpError()
4368 << "in a complete fuse the number of generatees must be exactly 1";
4369 }
4370 for (auto &&applyee : getApplyees()) {
4371 auto [create, gen, cons] = decodeCli(applyee);
4372
4373 if (!gen)
4374 return emitOpError() << "applyee CLI has no generator";
4375 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4376 if (!loop)
4377 return emitOpError()
4378 << "currently only supports omp.canonical_loop as applyee";
4379 }
4380 return success();
4381}
4382std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4383 return getODSOperandIndexAndLength(odsIndex_applyees);
4384}
4385
4386std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4387 return getODSOperandIndexAndLength(odsIndex_generatees);
4388}
4389
4390//===----------------------------------------------------------------------===//
4391// Critical construct (2.17.1)
4392//===----------------------------------------------------------------------===//
4393
4394void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
4395 const CriticalDeclareOperands &clauses) {
4396 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4397}
4398
4399LogicalResult CriticalDeclareOp::verify() {
4400 return verifySynchronizationHint(*this, getHint());
4401}
4402
4403LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4404 if (getNameAttr()) {
4405 SymbolRefAttr symbolRef = getNameAttr();
4406 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
4407 *this, symbolRef);
4408 if (!decl) {
4409 return emitOpError() << "expected symbol reference " << symbolRef
4410 << " to point to a critical declaration";
4411 }
4412 }
4413
4414 return success();
4415}
4416
4417//===----------------------------------------------------------------------===//
4418// Ordered construct
4419//===----------------------------------------------------------------------===//
4420
4421static LogicalResult verifyOrderedParent(Operation &op) {
4422 bool hasRegion = op.getNumRegions() > 0;
4423 auto loopOp = op.getParentOfType<LoopNestOp>();
4424 if (!loopOp) {
4425 if (hasRegion)
4426 return success();
4427
4428 // TODO: Consider if this needs to be the case only for the standalone
4429 // variant of the ordered construct.
4430 return op.emitOpError() << "must be nested inside of a loop";
4431 }
4432
4433 Operation *wrapper = loopOp->getParentOp();
4434 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4435 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4436 if (!orderedAttr)
4437 return op.emitOpError() << "the enclosing worksharing-loop region must "
4438 "have an ordered clause";
4439
4440 if (hasRegion && orderedAttr.getInt() != 0)
4441 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4442 "have a parameter present";
4443
4444 if (!hasRegion && orderedAttr.getInt() == 0)
4445 return op.emitOpError() << "the enclosing loop's ordered clause must "
4446 "have a parameter present";
4447 } else if (!isa<SimdOp>(wrapper)) {
4448 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4449 "or worksharing simd loop";
4450 }
4451 return success();
4452}
4453
4454void OrderedOp::build(OpBuilder &builder, OperationState &state,
4455 const OrderedOperands &clauses) {
4456 OrderedOp::build(builder, state, clauses.doacrossDependType,
4457 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4458}
4459
4460LogicalResult OrderedOp::verify() {
4461 if (failed(verifyOrderedParent(**this)))
4462 return failure();
4463
4464 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4465 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4466 return emitOpError() << "number of variables in depend clause does not "
4467 << "match number of iteration variables in the "
4468 << "doacross loop";
4469
4470 return success();
4471}
4472
4473void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4474 const OrderedRegionOperands &clauses) {
4475 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4476}
4477
4478LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4479
4480//===----------------------------------------------------------------------===//
4481// TaskwaitOp
4482//===----------------------------------------------------------------------===//
4483
4484void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
4485 const TaskwaitOperands &clauses) {
4486 // TODO Store clauses in op: dependKinds, dependVars, nowait.
4487 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
4488 /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
4489 /*depend_iterated=*/{}, /*nowait=*/nullptr);
4490}
4491
4492//===----------------------------------------------------------------------===//
4493// Verifier for AtomicReadOp
4494//===----------------------------------------------------------------------===//
4495
4496LogicalResult AtomicReadOp::verify() {
4497 if (verifyCommon().failed())
4498 return mlir::failure();
4499
4500 if (auto mo = getMemoryOrder()) {
4501 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4502 *mo == ClauseMemoryOrderKind::Release) {
4503 return emitError(
4504 "memory-order must not be acq_rel or release for atomic reads");
4505 }
4506 }
4507 return verifySynchronizationHint(*this, getHint());
4508}
4509
4510//===----------------------------------------------------------------------===//
4511// Verifier for AtomicWriteOp
4512//===----------------------------------------------------------------------===//
4513
4514LogicalResult AtomicWriteOp::verify() {
4515 if (verifyCommon().failed())
4516 return mlir::failure();
4517
4518 if (auto mo = getMemoryOrder()) {
4519 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4520 *mo == ClauseMemoryOrderKind::Acquire) {
4521 return emitError(
4522 "memory-order must not be acq_rel or acquire for atomic writes");
4523 }
4524 }
4525 return verifySynchronizationHint(*this, getHint());
4526}
4527
4528//===----------------------------------------------------------------------===//
4529// Verifier for AtomicUpdateOp
4530//===----------------------------------------------------------------------===//
4531
4532LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4533 PatternRewriter &rewriter) {
4534 if (op.isNoOp()) {
4535 rewriter.eraseOp(op);
4536 return success();
4537 }
4538 if (Value writeVal = op.getWriteOpVal()) {
4539 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4540 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4541 return success();
4542 }
4543 return failure();
4544}
4545
4546LogicalResult AtomicUpdateOp::verify() {
4547 if (verifyCommon().failed())
4548 return mlir::failure();
4549
4550 if (auto mo = getMemoryOrder()) {
4551 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4552 *mo == ClauseMemoryOrderKind::Acquire) {
4553 return emitError(
4554 "memory-order must not be acq_rel or acquire for atomic updates");
4555 }
4556 }
4557
4558 return verifySynchronizationHint(*this, getHint());
4559}
4560
4561LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4562
4563//===----------------------------------------------------------------------===//
4564// Verifier for AtomicCaptureOp
4565//===----------------------------------------------------------------------===//
4566
4567AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4568 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4569 return op;
4570 return dyn_cast<AtomicReadOp>(getSecondOp());
4571}
4572
4573AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4574 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4575 return op;
4576 return dyn_cast<AtomicWriteOp>(getSecondOp());
4577}
4578
4579AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4580 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4581 return op;
4582 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4583}
4584
4585LogicalResult AtomicCaptureOp::verify() {
4586 return verifySynchronizationHint(*this, getHint());
4587}
4588
4589LogicalResult AtomicCaptureOp::verifyRegions() {
4590 if (verifyRegionsCommon().failed())
4591 return mlir::failure();
4592
4593 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4594 return emitOpError(
4595 "operations inside capture region must not have hint clause");
4596
4597 if (getFirstOp()->getAttr("memory_order") ||
4598 getSecondOp()->getAttr("memory_order"))
4599 return emitOpError(
4600 "operations inside capture region must not have memory_order clause");
4601 return success();
4602}
4603
4604//===----------------------------------------------------------------------===//
4605// CancelOp
4606//===----------------------------------------------------------------------===//
4607
4608void CancelOp::build(OpBuilder &builder, OperationState &state,
4609 const CancelOperands &clauses) {
4610 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4611}
4612
4614 Operation *parent = thisOp->getParentOp();
4615 while (parent) {
4616 if (parent->getDialect() == thisOp->getDialect())
4617 return parent;
4618 parent = parent->getParentOp();
4619 }
4620 return nullptr;
4621}
4622
4623LogicalResult CancelOp::verify() {
4624 ClauseCancellationConstructType cct = getCancelDirective();
4625 // The next OpenMP operation in the chain of parents
4626 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4627 if (!structuralParent)
4628 return emitOpError() << "Orphaned cancel construct";
4629
4630 if ((cct == ClauseCancellationConstructType::Parallel) &&
4631 !mlir::isa<ParallelOp>(structuralParent)) {
4632 return emitOpError() << "cancel parallel must appear "
4633 << "inside a parallel region";
4634 }
4635 if (cct == ClauseCancellationConstructType::Loop) {
4636 // structural parent will be omp.loop_nest, directly nested inside
4637 // omp.wsloop
4638 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4639
4640 if (!wsloopOp) {
4641 return emitOpError()
4642 << "cancel loop must appear inside a worksharing-loop region";
4643 }
4644 if (wsloopOp.getNowaitAttr()) {
4645 return emitError() << "A worksharing construct that is canceled "
4646 << "must not have a nowait clause";
4647 }
4648 if (wsloopOp.getOrderedAttr()) {
4649 return emitError() << "A worksharing construct that is canceled "
4650 << "must not have an ordered clause";
4651 }
4652
4653 } else if (cct == ClauseCancellationConstructType::Sections) {
4654 // structural parent will be an omp.section, directly nested inside
4655 // omp.sections
4656 auto sectionsOp =
4657 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4658 if (!sectionsOp) {
4659 return emitOpError() << "cancel sections must appear "
4660 << "inside a sections region";
4661 }
4662 if (sectionsOp.getNowait()) {
4663 return emitError() << "A sections construct that is canceled "
4664 << "must not have a nowait clause";
4665 }
4666 }
4667 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4668 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4669 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4670 return emitOpError() << "cancel taskgroup must appear "
4671 << "inside a task region";
4672 }
4673 return success();
4674}
4675
4676//===----------------------------------------------------------------------===//
4677// CancellationPointOp
4678//===----------------------------------------------------------------------===//
4679
4680void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4681 const CancellationPointOperands &clauses) {
4682 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4683}
4684
4685LogicalResult CancellationPointOp::verify() {
4686 ClauseCancellationConstructType cct = getCancelDirective();
4687 // The next OpenMP operation in the chain of parents
4688 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4689 if (!structuralParent)
4690 return emitOpError() << "Orphaned cancellation point";
4691
4692 if ((cct == ClauseCancellationConstructType::Parallel) &&
4693 !mlir::isa<ParallelOp>(structuralParent)) {
4694 return emitOpError() << "cancellation point parallel must appear "
4695 << "inside a parallel region";
4696 }
4697 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4698 // find the wsloop
4699 if ((cct == ClauseCancellationConstructType::Loop) &&
4700 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4701 return emitOpError() << "cancellation point loop must appear "
4702 << "inside a worksharing-loop region";
4703 }
4704 if ((cct == ClauseCancellationConstructType::Sections) &&
4705 !mlir::isa<omp::SectionOp>(structuralParent)) {
4706 return emitOpError() << "cancellation point sections must appear "
4707 << "inside a sections region";
4708 }
4709 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4710 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4711 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4712 return emitOpError() << "cancellation point taskgroup must appear "
4713 << "inside a task region";
4714 }
4715 return success();
4716}
4717
4718//===----------------------------------------------------------------------===//
4719// MapBoundsOp
4720//===----------------------------------------------------------------------===//
4721
4722LogicalResult MapBoundsOp::verify() {
4723 auto extent = getExtent();
4724 auto upperbound = getUpperBound();
4725 if (!extent && !upperbound)
4726 return emitError("expected extent or upperbound.");
4727 return success();
4728}
4729
4730void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4731 TypeRange /*result_types*/, StringAttr symName,
4732 TypeAttr type) {
4733 PrivateClauseOp::build(
4734 odsBuilder, odsState, symName, type,
4735 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4736 DataSharingClauseType::Private));
4737}
4738
4739LogicalResult PrivateClauseOp::verifyRegions() {
4740 Type argType = getArgType();
4741 auto verifyTerminator = [&](Operation *terminator,
4742 bool yieldsValue) -> LogicalResult {
4743 if (!terminator->getBlock()->getSuccessors().empty())
4744 return success();
4745
4746 if (!llvm::isa<YieldOp>(terminator))
4747 return mlir::emitError(terminator->getLoc())
4748 << "expected exit block terminator to be an `omp.yield` op.";
4749
4750 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4751 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4752
4753 if (!yieldsValue) {
4754 if (yieldedTypes.empty())
4755 return success();
4756
4757 return mlir::emitError(terminator->getLoc())
4758 << "Did not expect any values to be yielded.";
4759 }
4760
4761 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4762 return success();
4763
4764 auto error = mlir::emitError(yieldOp.getLoc())
4765 << "Invalid yielded value. Expected type: " << argType
4766 << ", got: ";
4767
4768 if (yieldedTypes.empty())
4769 error << "None";
4770 else
4771 error << yieldedTypes;
4772
4773 return error;
4774 };
4775
4776 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4777 StringRef regionName,
4778 bool yieldsValue) -> LogicalResult {
4779 assert(!region.empty());
4780
4781 if (region.getNumArguments() != expectedNumArgs)
4782 return mlir::emitError(region.getLoc())
4783 << "`" << regionName << "`: "
4784 << "expected " << expectedNumArgs
4785 << " region arguments, got: " << region.getNumArguments();
4786
4787 for (Block &block : region) {
4788 // MLIR will verify the absence of the terminator for us.
4789 if (!block.mightHaveTerminator())
4790 continue;
4791
4792 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4793 return failure();
4794 }
4795
4796 return success();
4797 };
4798
4799 // Ensure all of the region arguments have the same type
4800 for (Region *region : getRegions())
4801 for (Type ty : region->getArgumentTypes())
4802 if (ty != argType)
4803 return emitError() << "Region argument type mismatch: got " << ty
4804 << " expected " << argType << ".";
4805
4806 mlir::Region &initRegion = getInitRegion();
4807 if (!initRegion.empty() &&
4808 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4809 /*yieldsValue=*/true)))
4810 return failure();
4811
4812 DataSharingClauseType dsType = getDataSharingType();
4813
4814 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4815 return emitError("`private` clauses do not require a `copy` region.");
4816
4817 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4818 return emitError(
4819 "`firstprivate` clauses require at least a `copy` region.");
4820
4821 if (dsType == DataSharingClauseType::FirstPrivate &&
4822 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4823 /*yieldsValue=*/true)))
4824 return failure();
4825
4826 if (!getDeallocRegion().empty() &&
4827 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4828 /*yieldsValue=*/false)))
4829 return failure();
4830
4831 return success();
4832}
4833
4834//===----------------------------------------------------------------------===//
4835// Spec 5.2: Masked construct (10.5)
4836//===----------------------------------------------------------------------===//
4837
4838void MaskedOp::build(OpBuilder &builder, OperationState &state,
4839 const MaskedOperands &clauses) {
4840 MaskedOp::build(builder, state, clauses.filteredThreadId);
4841}
4842
4843//===----------------------------------------------------------------------===//
4844// Spec 5.2: Scan construct (5.6)
4845//===----------------------------------------------------------------------===//
4846
4847void ScanOp::build(OpBuilder &builder, OperationState &state,
4848 const ScanOperands &clauses) {
4849 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4850}
4851
4852LogicalResult ScanOp::verify() {
4853 if (hasExclusiveVars() == hasInclusiveVars())
4854 return emitError(
4855 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4856 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4857 if (parentWsLoopOp.getReductionModAttr() &&
4858 parentWsLoopOp.getReductionModAttr().getValue() ==
4859 ReductionModifier::inscan)
4860 return success();
4861 }
4862 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4863 if (parentSimdOp.getReductionModAttr() &&
4864 parentSimdOp.getReductionModAttr().getValue() ==
4865 ReductionModifier::inscan)
4866 return success();
4867 }
4868 return emitError("SCAN directive needs to be enclosed within a parent "
4869 "worksharing loop construct or SIMD construct with INSCAN "
4870 "reduction modifier");
4871}
4872
4873/// Verifies align clause in allocate directive
4874LogicalResult verifyAlignment(Operation &op,
4875 std::optional<uint64_t> alignment) {
4876 if (alignment.has_value()) {
4877 if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
4878 return op.emitError()
4879 << "ALIGN value : " << alignment.value() << " must be power of 2";
4880 }
4881 return success();
4882}
4883
4884LogicalResult AllocateDirOp::verify() {
4885 return verifyAlignment(*getOperation(), getAlign());
4886}
4887
4888//===----------------------------------------------------------------------===//
4889// AllocSharedMemOp
4890//===----------------------------------------------------------------------===//
4891
4892LogicalResult AllocSharedMemOp::verify() {
4893 return verifyAlignment(*getOperation(), getMemAlignment());
4894}
4895
4896//===----------------------------------------------------------------------===//
4897// FreeSharedMemOp
4898//===----------------------------------------------------------------------===//
4899
4900LogicalResult FreeSharedMemOp::verify() {
4901 return verifyAlignment(*getOperation(), getMemAlignment());
4902}
4903
4904//===----------------------------------------------------------------------===//
4905// WorkdistributeOp
4906//===----------------------------------------------------------------------===//
4907
4908LogicalResult WorkdistributeOp::verify() {
4909 // Check that region exists and is not empty
4910 Region &region = getRegion();
4911 if (region.empty())
4912 return emitOpError("region cannot be empty");
4913 // Verify single entry point.
4914 Block &entryBlock = region.front();
4915 if (entryBlock.empty())
4916 return emitOpError("region must contain a structured block");
4917 // Verify single exit point.
4918 bool hasTerminator = false;
4919 for (Block &block : region) {
4920 if (isa<TerminatorOp>(block.back())) {
4921 if (hasTerminator) {
4922 return emitOpError("region must have exactly one terminator");
4923 }
4924 hasTerminator = true;
4925 }
4926 }
4927 if (!hasTerminator) {
4928 return emitOpError("region must be terminated with omp.terminator");
4929 }
4930 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4931 // No implicit barrier at end
4932 if (isa<BarrierOp>(op)) {
4933 return emitOpError(
4934 "explicit barriers are not allowed in workdistribute region");
4935 }
4936 // Check for invalid nested constructs
4937 if (isa<ParallelOp>(op)) {
4938 return emitOpError(
4939 "nested parallel constructs not allowed in workdistribute");
4940 }
4941 if (isa<TeamsOp>(op)) {
4942 return emitOpError(
4943 "nested teams constructs not allowed in workdistribute");
4944 }
4945 return WalkResult::advance();
4946 });
4947 if (walkResult.wasInterrupted())
4948 return failure();
4949
4950 Operation *parentOp = (*this)->getParentOp();
4951 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4952 return emitOpError("workdistribute must be nested under teams");
4953 return success();
4954}
4955
4956//===----------------------------------------------------------------------===//
4957// Declare simd [7.7]
4958//===----------------------------------------------------------------------===//
4959
4960LogicalResult DeclareSimdOp::verify() {
4961 // Must be nested inside a function-like op
4962 auto func =
4963 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4964 if (!func)
4965 return emitOpError() << "must be nested inside a function";
4966
4967 if (getInbranch() && getNotinbranch())
4968 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
4969
4970 if (failed(verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars(),
4971 /*isDeclareSimd=*/true)))
4972 return failure();
4973
4974 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
4975}
4976
4977void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4978 const DeclareSimdOperands &clauses) {
4979 MLIRContext *ctx = odsBuilder.getContext();
4980 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4981 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
4982 clauses.linearVars, clauses.linearStepVars,
4983 clauses.linearVarTypes, clauses.linearModifiers,
4984 clauses.notinbranch, clauses.simdlen,
4985 clauses.uniformVars);
4986}
4987
4988//===----------------------------------------------------------------------===//
4989// Parser and printer for Uniform Clause
4990//===----------------------------------------------------------------------===//
4991
4992/// uniform ::= `uniform` `(` uniform-list `)`
4993/// uniform-list := uniform-val (`,` uniform-val)*
4994/// uniform-val := ssa-id `:` type
4995static ParseResult
4998 SmallVectorImpl<Type> &uniformTypes) {
4999 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
5000 if (parser.parseOperand(uniformVars.emplace_back()) ||
5001 parser.parseColonType(uniformTypes.emplace_back()))
5002 return mlir::failure();
5003 return mlir::success();
5004 });
5005}
5006
5007/// Print Uniform Clauses
5009 ValueRange uniformVars, TypeRange uniformTypes) {
5010 for (unsigned i = 0; i < uniformVars.size(); ++i) {
5011 if (i != 0)
5012 p << ", ";
5013 p << uniformVars[i] << " : " << uniformTypes[i];
5014 }
5015}
5016
5017//===----------------------------------------------------------------------===//
5018// Parser and printer for Affinity Clause
5019//===----------------------------------------------------------------------===//
5020
5021static ParseResult parseAffinityClause(
5022 OpAsmParser &parser,
5025 SmallVectorImpl<Type> &iteratedTypes,
5026 SmallVectorImpl<Type> &affinityVarTypes) {
5027 if (failed(parseSplitIteratedList(
5028 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
5029 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
5030 return failure();
5031 return success();
5032}
5033
5035 ValueRange iterated, ValueRange affinityVars,
5036 TypeRange iteratedTypes,
5037 TypeRange affinityVarTypes) {
5038 auto nop = [&](Value, Type) {};
5039 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
5040 affinityVarTypes,
5041 /*plain prefix*/ nop,
5042 /*iterated prefix*/ nop);
5043}
5044
5045//===----------------------------------------------------------------------===//
5046// Parser, printer, and verifier for Iterator modifier
5047//===----------------------------------------------------------------------===//
5048
5049static ParseResult
5054 SmallVectorImpl<Type> &lbTypes,
5055 SmallVectorImpl<Type> &ubTypes,
5056 SmallVectorImpl<Type> &stepTypes) {
5057
5058 llvm::SMLoc ivLoc = parser.getCurrentLocation();
5060
5061 // Parse induction variables: %i : i32, %j : i32
5062 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
5063 OpAsmParser::Argument &arg = ivArgs.emplace_back();
5064 if (parser.parseArgument(arg))
5065 return failure();
5066
5067 // Optional type, default to Index if not provided
5068 if (succeeded(parser.parseOptionalColon())) {
5069 if (parser.parseType(arg.type))
5070 return failure();
5071 } else {
5072 arg.type = parser.getBuilder().getIndexType();
5073 }
5074 return success();
5075 }))
5076 return failure();
5077
5078 // ) = (
5079 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
5080 return failure();
5081
5082 // Parse Ranges: (%lb to %ub step %st, ...)
5083 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
5084 OpAsmParser::UnresolvedOperand lb, ub, st;
5085 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
5086 parser.parseOperand(ub) || parser.parseKeyword("step") ||
5087 parser.parseOperand(st))
5088 return failure();
5089
5090 lbs.push_back(lb);
5091 ubs.push_back(ub);
5092 steps.push_back(st);
5093 return success();
5094 }))
5095 return failure();
5096
5097 if (parser.parseRParen())
5098 return failure();
5099
5100 if (ivArgs.size() != lbs.size())
5101 return parser.emitError(ivLoc)
5102 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
5103 << " ranges";
5104
5105 for (auto &arg : ivArgs) {
5106 lbTypes.push_back(arg.type);
5107 ubTypes.push_back(arg.type);
5108 stepTypes.push_back(arg.type);
5109 }
5110
5111 return parser.parseRegion(region, ivArgs);
5112}
5113
5115 ValueRange lbs, ValueRange ubs,
5117 TypeRange) {
5118 Block &entry = region.front();
5119
5120 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
5121 if (i != 0)
5122 p << ", ";
5123 p.printRegionArgument(entry.getArgument(i));
5124 }
5125 p << ") = (";
5126
5127 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
5128 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
5129 if (i)
5130 p << ", ";
5131 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
5132 }
5133 p << ") ";
5134
5135 p.printRegion(region, /*printEntryBlockArgs=*/false,
5136 /*printBlockTerminators=*/true);
5137}
5138
5139LogicalResult IteratorOp::verify() {
5140 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
5141 if (!iteratedTy)
5142 return emitOpError() << "result must be omp.iterated<entry_ty>";
5143
5144 for (auto [lb, ub, step] : llvm::zip_equal(
5145 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5146 if (matchPattern(step, m_Zero()))
5147 return emitOpError() << "loop step must not be zero";
5148
5149 IntegerAttr lbAttr;
5150 IntegerAttr ubAttr;
5151 IntegerAttr stepAttr;
5152 if (!matchPattern(lb, m_Constant(&lbAttr)) ||
5153 !matchPattern(ub, m_Constant(&ubAttr)) ||
5154 !matchPattern(step, m_Constant(&stepAttr)))
5155 continue;
5156
5157 const APInt &lbVal = lbAttr.getValue();
5158 const APInt &ubVal = ubAttr.getValue();
5159 const APInt &stepVal = stepAttr.getValue();
5160 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5161 return emitOpError() << "positive loop step requires lower bound to be "
5162 "less than or equal to upper bound";
5163 if (stepVal.isNegative() && lbVal.slt(ubVal))
5164 return emitOpError() << "negative loop step requires lower bound to be "
5165 "greater than or equal to upper bound";
5166 }
5167
5168 Block &b = getRegion().front();
5169 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
5170
5171 if (!yield)
5172 return emitOpError() << "region must be terminated by omp.yield";
5173
5174 if (yield.getNumOperands() != 1)
5175 return emitOpError()
5176 << "omp.yield in omp.iterator region must yield exactly one value";
5177
5178 mlir::Type yieldedTy = yield.getOperand(0).getType();
5179 mlir::Type elemTy = iteratedTy.getElementType();
5180
5181 if (yieldedTy != elemTy)
5182 return emitOpError() << "omp.iterated element type (" << elemTy
5183 << ") does not match omp.yield operand type ("
5184 << yieldedTy << ")";
5185
5186 return success();
5187}
5188
5189//===----------------------------------------------------------------------===//
5190// GroupprivateOp
5191//===----------------------------------------------------------------------===//
5192
5193LogicalResult
5194GroupprivateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5195 auto *symbol = symbolTable.lookupNearestSymbolFrom(*this, getSymNameAttr());
5196 if (!symbol)
5197 return emitOpError() << "expected symbol reference '" << getSymName()
5198 << "' to point to a global variable";
5199
5200 if (isa<FunctionOpInterface>(symbol))
5201 return emitOpError() << "expected symbol reference '" << getSymName()
5202 << "' to point to a global variable, not a function";
5203
5204 return success();
5205}
5206
5207#define GET_ATTRDEF_CLASSES
5208#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5209
5210#define GET_OP_CLASSES
5211#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5212
5213#define GET_TYPEDEF_CLASSES
5214#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1510
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, TypeAttr inType, ValueRange typeparams, TypeRange typeparamsTypes, ValueRange shape, TypeRange shapeTypes)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op, AccessGroupModifierAttr modifierFirst, FallbackModifierAttr modifierSecond, Value dynGroupprivateSize, Type sizeType)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseHeapAllocClause(OpAsmParser &parser, TypeAttr &inTypeAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &typeparams, SmallVectorImpl< Type > &typeparamsTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &shape, SmallVectorImpl< Type > &shapeTypes)
operation ::= $in_type ( ( $typeparams ) )? ( , $shape )?
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseDynGroupprivateClause(OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr, FallbackModifierAttr &fallbackAttr, std::optional< OpAsmParser::UnresolvedOperand > &dynGroupprivateSize, Type &sizeType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup, FallbackModifierAttr fallback, Value dynGroupprivateSize)
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
LogicalResult verifyAlignment(Operation &op, std::optional< uint64_t > alignment)
Verifies align clause in allocate directive.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgListType getArguments()
Definition Block.h:97
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:143
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:161
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.