MLIR 23.0.0git
OpenACCCG.cpp
Go to the documentation of this file.
1//===- OpenACCCG.cpp - OpenACC codegen ops, attributes, and types ---------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation for OpenACC codegen operations, attributes, and types.
10// These correspond to the definitions in OpenACCCG*.td tablegen files
11// and are kept in a separate file because they do not represent direct mappings
12// of OpenACC language constructs; they are intermediate representations used
13// when decomposing and lowering primary `acc` dialect operations.
14//
15//===----------------------------------------------------------------------===//
16
23#include "mlir/IR/Region.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallVector.h"
29
30using namespace mlir;
31using namespace acc;
32
33namespace {
34
35/// Generic helper for single-region OpenACC ops that execute their body once
36/// and then continue after the operation with their results (if any).
37static void
41 if (point.isParent()) {
42 regions.push_back(RegionSuccessor(&region));
43 return;
44 }
45 regions.push_back(RegionSuccessor(op));
46}
47
49 RegionSuccessor successor) {
50 return successor.isOperation() ? ValueRange(op->getResults()) : ValueRange();
51}
52
53/// Remove empty acc.kernel_environment operations. If the operation has wait
54/// operands, create a acc.wait operation to preserve synchronization.
55struct RemoveEmptyKernelEnvironment
56 : public OpRewritePattern<acc::KernelEnvironmentOp> {
57 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
58
59 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
60 PatternRewriter &rewriter) const override {
61 assert(op->getNumRegions() == 1 && "expected op to have one region");
62
63 Block &block = op.getRegion().front();
64 if (!block.empty())
65 return failure();
66
67 // Remove empty kernel environment.
68 // Preserve synchronization by creating acc.wait operation if needed.
69 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
70 rewriter.replaceOpWithNewOp<acc::WaitOp>(
71 op, op.getWaitOperands(), /*asyncOperand=*/Value(),
72 op.getWaitDevnum(), /*async=*/nullptr, /*ifCond=*/Value());
73 else
74 rewriter.eraseOp(op);
75
76 return success();
77 }
78};
79
80static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
81 PatternRewriter &rewriter,
82 size_t numInput) {
83 const size_t numLaunch = op.getLaunchArgs().size();
84 op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
85 rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunch),
86 static_cast<int32_t>(numInput),
87 op.getStream() ? 1 : 0}));
88}
89
90struct ComputeRegionRemoveDuplicateArgs
91 : public OpRewritePattern<ComputeRegionOp> {
93
94 LogicalResult matchAndRewrite(ComputeRegionOp op,
95 PatternRewriter &rewriter) const override {
96 Block *body = op.getBody();
97 const size_t numLaunch = op.getLaunchArgs().size();
98 size_t numInput = op.getInputArgs().size();
99 assert(body->getNumArguments() == numLaunch + numInput &&
100 "region args mismatch");
101
102 bool mergedAny = false;
103 while (true) {
104 bool merged = false;
105 for (size_t j = 1; j < numInput && !merged; ++j) {
106 for (size_t i = 0; i < j; ++i) {
107 if (op->getOperand(static_cast<unsigned>(numLaunch + i)) !=
108 op->getOperand(static_cast<unsigned>(numLaunch + j)))
109 continue;
110 unsigned keepIdx = static_cast<unsigned>(numLaunch + i);
111 unsigned dropIdx = static_cast<unsigned>(numLaunch + j);
112 rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
113 body->getArgument(keepIdx));
114 body->eraseArgument(dropIdx);
115 op->eraseOperand(dropIdx);
116 --numInput;
117 merged = true;
118 mergedAny = true;
119 break;
120 }
121 }
122 if (!merged)
123 break;
124 }
125
126 if (!mergedAny)
127 return failure();
128 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
129 return success();
130 }
131};
132
133struct ComputeRegionRemoveUnusedArgs
134 : public OpRewritePattern<ComputeRegionOp> {
136
137 LogicalResult matchAndRewrite(ComputeRegionOp op,
138 PatternRewriter &rewriter) const override {
139 Block *body = op.getBody();
140 const size_t numLaunch = op.getLaunchArgs().size();
141 size_t numInput = op.getInputArgs().size();
142 assert(body->getNumArguments() == numLaunch + numInput &&
143 "region args mismatch");
144
145 bool changed = false;
146 for (size_t k = numLaunch; k < numLaunch + numInput;) {
147 if (!body->getArgument(static_cast<unsigned>(k)).use_empty()) {
148 ++k;
149 continue;
150 }
151 body->eraseArgument(static_cast<unsigned>(k));
152 op->eraseOperand(static_cast<unsigned>(k));
153 --numInput;
154 changed = true;
155 }
156
157 if (!changed)
158 return failure();
159 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
160 return success();
161 }
162};
163
164template <typename EffectTy>
165static void addOperandEffect(
167 &effects,
168 const MutableOperandRange &operand) {
169 for (unsigned i = 0, e = operand.size(); i < e; ++i)
170 effects.emplace_back(EffectTy::get(), &operand[i]);
171}
172
173template <typename EffectTy>
174static void addResultEffect(
176 &effects,
177 Value result) {
178 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
179}
180
181static int64_t gpuProcessorIndex(gpu::Processor p) {
182 switch (p) {
183 case gpu::Processor::Sequential:
184 return 0;
185 case gpu::Processor::ThreadX:
186 return 1;
187 case gpu::Processor::ThreadY:
188 return 2;
189 case gpu::Processor::ThreadZ:
190 return 3;
191 case gpu::Processor::BlockX:
192 return 4;
193 case gpu::Processor::BlockY:
194 return 5;
195 case gpu::Processor::BlockZ:
196 return 6;
197 }
198 llvm_unreachable("unhandled gpu::Processor");
199}
200
201static gpu::Processor indexToGpuProcessor(int64_t idx) {
202 switch (idx) {
203 case 0:
204 return gpu::Processor::Sequential;
205 case 1:
206 return gpu::Processor::ThreadX;
207 case 2:
208 return gpu::Processor::ThreadY;
209 case 3:
210 return gpu::Processor::ThreadZ;
211 case 4:
212 return gpu::Processor::BlockX;
213 case 5:
214 return gpu::Processor::BlockY;
215 case 6:
216 return gpu::Processor::BlockZ;
217 default:
218 return gpu::Processor::Sequential;
219 }
220}
221
222static GPUParallelDimAttr intToParDim(MLIRContext *context, int64_t dimInt) {
223 return GPUParallelDimAttr::get(
224 context, IntegerAttr::get(IndexType::get(context), dimInt));
225}
226
227static GPUParallelDimAttr processorParDim(MLIRContext *context,
228 gpu::Processor proc) {
229 return GPUParallelDimAttr::get(
230 context,
231 IntegerAttr::get(IndexType::get(context), gpuProcessorIndex(proc)));
232}
233
234static ParseResult parseProcessorValue(AsmParser &parser,
235 GPUParallelDimAttr &dim) {
236 std::string keyword;
237 llvm::SMLoc loc = parser.getCurrentLocation();
238 if (failed(parser.parseKeywordOrString(&keyword)))
239 return failure();
240 auto maybeProcessor = gpu::symbolizeProcessor(keyword);
241 if (!maybeProcessor)
242 return parser.emitError(loc)
243 << "expected one of ::mlir::gpu::Processor enum names";
244 dim = intToParDim(parser.getContext(), gpuProcessorIndex(*maybeProcessor));
245 return success();
246}
247
248static void printProcessorValue(AsmPrinter &printer,
249 const GPUParallelDimAttr &attr) {
250 gpu::Processor processor = indexToGpuProcessor(attr.getValue().getInt());
251 printer << gpu::stringifyProcessor(processor);
252}
253
254} // namespace
255
256//===----------------------------------------------------------------------===//
257// KernelEnvironmentOp
258//===----------------------------------------------------------------------===//
259
260void KernelEnvironmentOp::getSuccessorRegions(
262 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
263 regions);
264}
265
266ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
267 return getSingleRegionSuccessorInputs(getOperation(), successor);
268}
269
270void KernelEnvironmentOp::getCanonicalizationPatterns(
271 RewritePatternSet &results, MLIRContext *context) {
272 results.add<RemoveEmptyKernelEnvironment>(context);
273}
274
275/// Extract async for `clauseDeviceType`. Returns true if a clause was found.
276template <typename ComputeConstructT>
277static bool
278extractAsyncClause(ComputeConstructT computeConstruct,
279 DeviceType clauseDeviceType, MLIRContext *context,
280 std::optional<Value> &asyncOperand, UnitAttr &asyncOnly) {
281 if (computeConstruct.hasAsyncOnly(clauseDeviceType)) {
282 asyncOnly = UnitAttr::get(context);
283 return true;
284 }
285 if (Value asyncValue = computeConstruct.getAsyncValue(clauseDeviceType)) {
286 asyncOperand = asyncValue;
287 return true;
288 }
289 return false;
290}
291
292/// Extract wait for `clauseDeviceType`. Returns true if a clause was found.
293template <typename ComputeConstructT>
294static bool extractWaitClause(ComputeConstructT computeConstruct,
295 DeviceType clauseDeviceType, MLIRContext *context,
296 std::optional<Value> &waitDevnum,
297 SmallVectorImpl<Value> &waitOperands,
298 UnitAttr &waitOnly) {
299 if (computeConstruct.hasWaitOnly(clauseDeviceType)) {
300 waitOnly = UnitAttr::get(context);
301 return true;
302 }
303 Value devnum = computeConstruct.getWaitDevnum(clauseDeviceType);
304 auto waitValues = computeConstruct.getWaitValues(clauseDeviceType);
305 if (!devnum && waitValues.empty())
306 return false;
307 if (devnum)
308 waitDevnum = devnum;
309 waitOperands.append(waitValues.begin(), waitValues.end());
310 return true;
311}
312
313template <typename ComputeConstructT>
315 ComputeConstructT computeConstruct, DeviceType deviceType,
316 std::optional<Value> &asyncOperand, UnitAttr &asyncOnly,
317 std::optional<Value> &waitDevnum, SmallVectorImpl<Value> &waitOperands,
318 UnitAttr &waitOnly) {
319 MLIRContext *context = computeConstruct->getContext();
320
321 // Prefer device_type-specific clauses, then default ones.
322 if (!extractAsyncClause(computeConstruct, deviceType, context, asyncOperand,
323 asyncOnly)) {
324 if (deviceType != DeviceType::None)
325 extractAsyncClause(computeConstruct, DeviceType::None, context,
326 asyncOperand, asyncOnly);
327 }
328
329 if (!extractWaitClause(computeConstruct, deviceType, context, waitDevnum,
330 waitOperands, waitOnly)) {
331 if (deviceType != DeviceType::None)
332 extractWaitClause(computeConstruct, DeviceType::None, context, waitDevnum,
333 waitOperands, waitOnly);
334 }
335}
336
337template <typename ComputeConstructT>
338KernelEnvironmentOp
339KernelEnvironmentOp::createAndPopulate(ComputeConstructT computeConstruct,
340 DeviceType deviceType,
341 OpBuilder &builder) {
342 std::optional<Value> asyncOperand;
343 UnitAttr asyncOnly = nullptr;
344 std::optional<Value> waitDevnum;
345 SmallVector<Value> waitOperands;
346 UnitAttr waitOnly = nullptr;
347 populateKernelEnvironmentAsyncWait(computeConstruct, deviceType, asyncOperand,
348 asyncOnly, waitDevnum, waitOperands,
349 waitOnly);
350
351 auto kernelEnvironment = KernelEnvironmentOp::create(
352 builder, computeConstruct->getLoc(),
353 computeConstruct.getDataClauseOperands(), asyncOperand.value_or(Value()),
354 asyncOnly, waitDevnum.value_or(Value()), waitOperands, waitOnly);
355 Block &block = kernelEnvironment.getRegion().emplaceBlock();
356 builder.setInsertionPointToStart(&block);
357 return kernelEnvironment;
358}
359
360template KernelEnvironmentOp
361KernelEnvironmentOp::createAndPopulate<ParallelOp>(ParallelOp, DeviceType,
362 OpBuilder &);
363template KernelEnvironmentOp
364KernelEnvironmentOp::createAndPopulate<KernelsOp>(KernelsOp, DeviceType,
365 OpBuilder &);
366template KernelEnvironmentOp
367KernelEnvironmentOp::createAndPopulate<SerialOp>(SerialOp, DeviceType,
368 OpBuilder &);
369
370LogicalResult KernelEnvironmentOp::verify() {
371 if (getAsyncOnly() && getAsyncOperand())
372 return emitError("async-only cannot appear with async operand");
373 if (getWaitOnly() && (!getWaitOperands().empty() || getWaitDevnum()))
374 return emitError("wait-only cannot appear with wait operands or devnum");
375 return success();
376}
377
378//===----------------------------------------------------------------------===//
379// FirstprivateMapInitialOp
380//===----------------------------------------------------------------------===//
381
382LogicalResult FirstprivateMapInitialOp::verify() {
383 if (getDataClause() != acc::DataClause::acc_firstprivate)
384 return emitError("data clause associated with firstprivate operation must "
385 "match its intent");
386 if (!getVar())
387 return emitError("must have var operand");
388 if (!mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
389 !mlir::isa<mlir::acc::MappableType>(getVar().getType()))
390 return emitError("var must be mappable or pointer-like");
391 if (mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
392 getVarType() == getVar().getType())
393 return emitError("varType must capture the element type of var");
394 if (getModifiers() != acc::DataClauseModifier::none)
395 return emitError("no data clause modifiers are allowed");
396 return success();
397}
398
// Declares the memory effects of this op: a read effect whose remaining
// arguments are not visible here, plus a read effect on the `var` operand.
// NOTE(review): this listing omits embedded source lines 400, 403 and 405
// (the numbering jumps 399 -> 401 -> 402 -> 404 -> 406), so the parameter
// type, the rest of the first emplace_back call and one further statement are
// missing; restore them from the upstream file before building.
399void FirstprivateMapInitialOp::getEffects(
401 &effects) {
402 effects.emplace_back(MemoryEffects::Read::get(),
404 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
406}
407
408//===----------------------------------------------------------------------===//
409// ReductionInitOp
410//===----------------------------------------------------------------------===//
411
412void ReductionInitOp::getSuccessorRegions(
414 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
415 regions);
416}
417
418void ReductionInitOp::getRegionInvocationBounds(
419 ArrayRef<Attribute> operands,
420 SmallVectorImpl<InvocationBounds> &invocationBounds) {
421 invocationBounds.emplace_back(1, 1);
422}
423
424ValueRange ReductionInitOp::getSuccessorInputs(RegionSuccessor successor) {
425 return getSingleRegionSuccessorInputs(getOperation(), successor);
426}
427
428LogicalResult ReductionInitOp::verify() {
429 Block &block = getRegion().front();
430 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
431 if (yieldOp.getNumOperands() != 1)
432 return emitOpError(
433 "region must yield exactly one value (private storage)");
434 if (yieldOp.getOperand(0).getType() != getVar().getType())
435 return emitOpError("yielded value type must match var type");
436 }
437 return success();
438}
439
440//===----------------------------------------------------------------------===//
441// ReductionCombineRegionOp
442//===----------------------------------------------------------------------===//
443
444void ReductionCombineRegionOp::getSuccessorRegions(
446 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
447 regions);
448}
449
450void ReductionCombineRegionOp::getRegionInvocationBounds(
451 ArrayRef<Attribute> operands,
452 SmallVectorImpl<InvocationBounds> &invocationBounds) {
453 invocationBounds.emplace_back(1, 1);
454}
455
457ReductionCombineRegionOp::getSuccessorInputs(RegionSuccessor successor) {
458 return getSingleRegionSuccessorInputs(getOperation(), successor);
459}
460
461LogicalResult ReductionCombineRegionOp::verify() {
462 Block &block = getRegion().front();
463 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
464 if (yieldOp.getNumOperands() != 0)
465 return emitOpError("region must be terminated by acc.yield with no "
466 "operands");
467 }
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// ReductionAccumulateOp
473//===----------------------------------------------------------------------===//
474
475LogicalResult ReductionAccumulateOp::verify() {
476 Type valueType = getValue().getType();
477 auto ptrLikeTy = cast<PointerLikeType>(getMemref().getType());
478 Type elementType = ptrLikeTy.getElementType();
479 if (!elementType)
480 return emitOpError("pointer-like destination must have an element type");
481 if (elementType != valueType)
482 return emitOpError("pointer-like element type must match value type");
483 if (getParDims().getArray().empty())
484 return emitOpError("par_dims must specify at least one parallel dimension");
485 return success();
486}
487
488//===----------------------------------------------------------------------===//
489// ReductionAccumulateArrayOp
490//===----------------------------------------------------------------------===//
491
492LogicalResult ReductionAccumulateArrayOp::verify() {
493 if (getParDims().getArray().empty())
494 return emitOpError("par_dims must specify at least one parallel dimension");
495 return success();
496}
497
498//===----------------------------------------------------------------------===//
499// ReductionCombineOp
500//===----------------------------------------------------------------------===//
501
// Declares the memory effects of this op: a read on the source memref
// operand, and a read plus a write on the destination memref operand.
// NOTE(review): this listing omits embedded source lines 503, 506, 508 and
// 510 (the numbering jumps 502 -> 504 -> 505 -> 507 -> 509 -> 511), so the
// parameter type and the trailing argument of each emplace_back call are
// missing; restore them from the upstream file before building.
502void ReductionCombineOp::getEffects(
504 &effects) {
505 effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemrefMutable(),
507 effects.emplace_back(MemoryEffects::Read::get(), &getDestMemrefMutable(),
509 effects.emplace_back(MemoryEffects::Write::get(), &getDestMemrefMutable(),
511}
512
513//===----------------------------------------------------------------------===//
514// ComputeRegionOp
515//===----------------------------------------------------------------------===//
516
517static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op,
518 GPUParallelDimAttr parDim) {
519 for (auto launchArg : op.getLaunchArgs()) {
520 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
521 if (!parOp)
522 continue;
523 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
524 if (launchArgDim == parDim)
525 return parOp;
526 }
527 return nullptr;
528}
529
530std::optional<Value> ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
531 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
532 return parWidthOp.getResult();
533 return {};
534}
535
536std::optional<Value>
537ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
538 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
539 if (parWidthOp.getLaunchArg())
540 return parWidthOp.getLaunchArg();
541 return {};
542}
543
544std::optional<uint64_t>
545ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
546 auto knownParWidth = getKnownLaunchArg(parDim);
547 if (knownParWidth.has_value())
548 return getConstantIntValue(knownParWidth.value());
549 return {};
550}
551
552BlockArgument ComputeRegionOp::appendInputArg(Value value) {
553 getInputArgsMutable().append(value);
554 return getBody()->addArgument(value.getType(), getLoc());
555}
556
557std::optional<BlockArgument>
558ComputeRegionOp::wireHoistedValueThroughIns(Value value) {
559 Region &region = getRegion();
560
561 auto useIsInRegion = [&](OpOperand &use) -> bool {
562 return region.isAncestor(use.getOwner()->getParentRegion());
563 };
564
565 if (!areValuesDefinedAbove(ValueRange(value), region) ||
566 !llvm::any_of(value.getUses(), useIsInRegion))
567 return std::nullopt;
568
569 BlockArgument arg = appendInputArg(value);
570 replaceAllUsesInRegionWith(value, arg, region);
571 return arg;
572}
573
574bool ComputeRegionOp::isEffectivelySerial() {
575 auto *ctx = getContext();
576
577 if (getLaunchArg(GPUParallelDimAttr::seqDim(ctx)))
578 return true;
579
580 auto checkDim = [&](GPUParallelDimAttr dim) -> bool {
581 auto val = getKnownConstantLaunchArg(dim);
582 return val && *val == 1;
583 };
584
585 return checkDim(GPUParallelDimAttr::threadXDim(ctx)) &&
586 checkDim(GPUParallelDimAttr::threadYDim(ctx)) &&
587 checkDim(GPUParallelDimAttr::threadZDim(ctx)) &&
588 checkDim(GPUParallelDimAttr::blockXDim(ctx)) &&
589 checkDim(GPUParallelDimAttr::blockYDim(ctx)) &&
590 checkDim(GPUParallelDimAttr::blockZDim(ctx));
591}
592
593BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
594 for (auto [pos, launchArg] : llvm::enumerate(getLaunchArgs())) {
595 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
596 assert(parOp);
597 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
598 if (launchArgDim == parDim) {
599 assert(pos < getRegion().front().getNumArguments() &&
600 "launch arg position out of range");
601 return getRegion().front().getArgument(pos);
602 }
603 }
604 llvm_unreachable("attempting to get unspecified parDim");
605}
606
607SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
609 for (auto launchArg : getLaunchArgs()) {
610 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
611 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
612 int64_t dimInt = launchArgDim.getValue().getInt();
613 parDims.push_back(intToParDim(getContext(), dimInt));
614 }
615 return parDims;
616}
617
618Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
619 Block *body = getBody();
620 if (blockArg.getOwner() != body)
621 return Value();
622 unsigned argNumber = blockArg.getArgNumber();
623 unsigned numLaunchArgs = getLaunchArgs().size();
624 unsigned numInputArgs = getInputArgs().size();
625 if (argNumber >= numLaunchArgs + numInputArgs)
626 return Value();
627 if (argNumber < numLaunchArgs)
628 return getLaunchArgs()[argNumber];
629 return getInputArgs()[argNumber - numLaunchArgs];
630}
631
632std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
633 Block *body = getBody();
634 for (auto [idx, launchVal] : llvm::enumerate(getLaunchArgs())) {
635 if (launchVal == value)
636 return body->getArgument(idx);
637 }
638 unsigned numLaunch = getLaunchArgs().size();
639 for (auto [idx, inputVal] : llvm::enumerate(getInputArgs())) {
640 if (inputVal == value)
641 return body->getArgument(numLaunch + idx);
642 }
643 return std::nullopt;
644}
645
646void ComputeRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
647 MLIRContext *context) {
648 results.add<ComputeRegionRemoveDuplicateArgs, ComputeRegionRemoveUnusedArgs>(
649 context);
650}
651
652BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
653 return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
654}
655
656LogicalResult ComputeRegionOp::verify() {
657 for (auto op : getLaunchArgs())
658 if (!op.getDefiningOp<acc::ParWidthOp>())
659 return emitOpError(
660 "launch arguments must be results of acc.par_width operations");
661
662 unsigned expectedBlockArgs = getLaunchArgs().size() + getInputArgs().size();
663 unsigned actualBlockArgs = getRegion().front().getNumArguments();
664 if (expectedBlockArgs != actualBlockArgs)
665 return emitOpError("expected ")
666 << expectedBlockArgs << " block arguments (launch + input), got "
667 << actualBlockArgs;
668
669 return success();
670}
671
672void ComputeRegionOp::print(OpAsmPrinter &p) {
673 ValueRange regionArgs = getBody()->getArguments();
674 ValueRange launchArgs = getLaunchArgs();
675 ValueRange inputArgs = getInputArgs();
676
677 assert(regionArgs.size() == (launchArgs.size() + inputArgs.size()) &&
678 "region args mismatch");
679
680 if (getStream())
681 p << " stream(" << getStream() << " : " << getStream().getType() << ")";
682
683 size_t i = 0;
684 if (!launchArgs.empty()) {
685 p << " launch(";
686 for (size_t j = 0; j < launchArgs.size(); ++j, ++i) {
687 p << regionArgs[i] << " = " << launchArgs[j];
688 if (j < launchArgs.size() - 1)
689 p << ", ";
690 }
691 p << ")";
692 }
693 if (!inputArgs.empty()) {
694 p << " ins(";
695 for (size_t j = 0; j < inputArgs.size(); ++j, ++i) {
696 p << regionArgs[i] << " = " << inputArgs[j];
697 if (j < inputArgs.size() - 1)
698 p << ", ";
699 }
700 p << ") : (";
701 for (size_t j = 0; j < inputArgs.size(); ++j) {
702 p << inputArgs[j].getType();
703 if (j < inputArgs.size() - 1)
704 p << ", ";
705 }
706 p << ")";
707 }
708 p.printOptionalArrowTypeList(getResultTypes());
709 p << " ";
710 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
711 p.printOptionalAttrDict((*this)->getAttrs(),
712 /*elidedAttrs=*/getOperandSegmentSizeAttr());
713}
714
715ParseResult ComputeRegionOp::parse(OpAsmParser &parser,
717 auto &builder = parser.getBuilder();
718
720 OpAsmParser::UnresolvedOperand streamOperand;
721 Type streamType;
724 SmallVector<Type> types;
725
726 bool hasStream = false;
727 if (succeeded(parser.parseOptionalKeyword("stream"))) {
728 hasStream = true;
729 if (parser.parseLParen() || parser.parseOperand(streamOperand) ||
730 parser.parseColon() || parser.parseType(streamType) ||
731 parser.parseRParen())
732 return failure();
733 }
734
735 if (succeeded(parser.parseOptionalKeyword("launch"))) {
736 if (parser.parseAssignmentList(regionArgs, launchOperands))
737 return failure();
738 Type indexType = builder.getIndexType();
739 for (size_t i = 0; i < regionArgs.size(); ++i)
740 types.push_back(indexType);
741 }
742
743 if (succeeded(parser.parseOptionalKeyword("ins"))) {
744 if (parser.parseAssignmentList(regionArgs, inputOperands) ||
745 parser.parseColon() || parser.parseLParen() ||
746 parser.parseTypeList(types) || parser.parseRParen())
747 return failure();
748 }
749
750 if (parser.parseOptionalArrowTypeList(result.types))
751 return failure();
752
753 for (auto [iterArg, type] : llvm::zip_equal(regionArgs, types))
754 iterArg.type = type;
755
756 Region *body = result.addRegion();
757 if (parser.parseRegion(*body, regionArgs))
758 return failure();
759 ComputeRegionOp::ensureTerminator(*body, parser.getBuilder(),
760 result.location);
761
762 const size_t numLaunchOperands = launchOperands.size();
763 const size_t numInputOperands = inputOperands.size();
764 assert(numLaunchOperands + numInputOperands == regionArgs.size() &&
765 "compute region args mismatch");
766
767 result.addAttribute(
768 ComputeRegionOp::getOperandSegmentSizeAttr(),
769 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunchOperands),
770 static_cast<int32_t>(numInputOperands),
771 hasStream ? 1 : 0}));
772
773 for (size_t i = 0; i < numLaunchOperands; ++i) {
774 if (parser.resolveOperand(launchOperands[i], types[i], result.operands))
775 return failure();
776 }
777
778 for (size_t i = numLaunchOperands; i < regionArgs.size(); ++i) {
779 if (parser.resolveOperand(inputOperands[i - numLaunchOperands], types[i],
780 result.operands))
781 return failure();
782 }
783
784 if (hasStream) {
785 if (parser.resolveOperand(streamOperand, streamType, result.operands))
786 return failure();
787 }
788
789 if (parser.parseOptionalAttrDict(result.attributes))
790 return failure();
791
792 return success();
793}
794
795//===----------------------------------------------------------------------===//
796// PredicateRegionOp
797//===----------------------------------------------------------------------===//
798
799LogicalResult PredicateRegionOp::verify() {
800 if (getRegion().empty())
801 return emitOpError("region needs to have at least one block");
802 if (getRegion().front().getNumArguments() > 0)
803 return emitOpError("region cannot have any arguments");
804 if (!getOperation()->getParentOfType<ComputeRegionOp>())
805 return emitOpError("must be nested within an acc.compute_region operation");
806 return success();
807}
808
809//===----------------------------------------------------------------------===//
810// GPUParallelDimAttr
811//===----------------------------------------------------------------------===//
812
813GPUParallelDimAttr GPUParallelDimAttr::get(MLIRContext *context,
814 gpu::Processor proc) {
815 return processorParDim(context, proc);
816}
817
818GPUParallelDimAttr GPUParallelDimAttr::seqDim(MLIRContext *context) {
819 return processorParDim(context, gpu::Processor::Sequential);
820}
821
822GPUParallelDimAttr GPUParallelDimAttr::threadXDim(MLIRContext *context) {
823 return processorParDim(context, gpu::Processor::ThreadX);
824}
825
826GPUParallelDimAttr GPUParallelDimAttr::threadYDim(MLIRContext *context) {
827 return processorParDim(context, gpu::Processor::ThreadY);
828}
829
830GPUParallelDimAttr GPUParallelDimAttr::threadZDim(MLIRContext *context) {
831 return processorParDim(context, gpu::Processor::ThreadZ);
832}
833
834GPUParallelDimAttr GPUParallelDimAttr::blockXDim(MLIRContext *context) {
835 return processorParDim(context, gpu::Processor::BlockX);
836}
837
838GPUParallelDimAttr GPUParallelDimAttr::blockYDim(MLIRContext *context) {
839 return processorParDim(context, gpu::Processor::BlockY);
840}
841
842GPUParallelDimAttr GPUParallelDimAttr::blockZDim(MLIRContext *context) {
843 return processorParDim(context, gpu::Processor::BlockZ);
844}
845
846Attribute GPUParallelDimAttr::parse(AsmParser &parser, Type type) {
847 GPUParallelDimAttr dim;
848 if (parser.parseLess() || parseProcessorValue(parser, dim) ||
849 parser.parseGreater()) {
850 parser.emitError(parser.getCurrentLocation(),
851 "expected format `<` processor_name `>`");
852 return {};
853 }
854 return dim;
855}
856
857void GPUParallelDimAttr::print(AsmPrinter &printer) const {
858 printer << "<";
859 printProcessorValue(printer, *this);
860 printer << ">";
861}
862
863GPUParallelDimAttr GPUParallelDimAttr::threadDim(MLIRContext *context,
864 unsigned index) {
865 assert(index <= 2 && "thread dimension index must be 0, 1, or 2");
866 switch (index) {
867 case 0:
868 return threadXDim(context);
869 case 1:
870 return threadYDim(context);
871 case 2:
872 return threadZDim(context);
873 }
874 llvm_unreachable("validated thread dimension index");
875}
876
877GPUParallelDimAttr GPUParallelDimAttr::blockDim(MLIRContext *context,
878 unsigned index) {
879 assert(index <= 2 && "block dimension index must be 0, 1, or 2");
880 switch (index) {
881 case 0:
882 return blockXDim(context);
883 case 1:
884 return blockYDim(context);
885 case 2:
886 return blockZDim(context);
887 }
888 llvm_unreachable("validated block dimension index");
889}
890
891gpu::Processor GPUParallelDimAttr::getProcessor() const {
892 return indexToGpuProcessor(getValue().getInt());
893}
894
895int GPUParallelDimAttr::getOrder() const {
896 return gpuProcessorIndex(getProcessor());
897}
898
899GPUParallelDimAttr GPUParallelDimAttr::getOneHigher() const {
900 int order = getOrder();
901 if (order >= 6) // BlockZ is the highest
902 return *this;
903 return get(getContext(), indexToGpuProcessor(order + 1));
904}
905
906GPUParallelDimAttr GPUParallelDimAttr::getOneLower() const {
907 int order = getOrder();
908 if (order <= 0) // Sequential is the lowest
909 return *this;
910 return get(getContext(), indexToGpuProcessor(order - 1));
911}
912
913bool GPUParallelDimAttr::isSeq() const {
914 return getProcessor() == gpu::Processor::Sequential;
915}
916bool GPUParallelDimAttr::isThreadX() const {
917 return getProcessor() == gpu::Processor::ThreadX;
918}
919bool GPUParallelDimAttr::isThreadY() const {
920 return getProcessor() == gpu::Processor::ThreadY;
921}
922bool GPUParallelDimAttr::isThreadZ() const {
923 return getProcessor() == gpu::Processor::ThreadZ;
924}
925bool GPUParallelDimAttr::isBlockX() const {
926 return getProcessor() == gpu::Processor::BlockX;
927}
928bool GPUParallelDimAttr::isBlockY() const {
929 return getProcessor() == gpu::Processor::BlockY;
930}
931bool GPUParallelDimAttr::isBlockZ() const {
932 return getProcessor() == gpu::Processor::BlockZ;
933}
934bool GPUParallelDimAttr::isAnyThread() const {
935 return isThreadX() || isThreadY() || isThreadZ();
936}
937bool GPUParallelDimAttr::isAnyBlock() const {
938 return isBlockX() || isBlockY() || isBlockZ();
939}
940
941//===----------------------------------------------------------------------===//
942// GPUParallelDimsAttr
943//===----------------------------------------------------------------------===//
944
945GPUParallelDimsAttr GPUParallelDimsAttr::seq(MLIRContext *ctx) {
946 return GPUParallelDimsAttr::get(ctx, {GPUParallelDimAttr::seqDim(ctx)});
947}
948
949bool GPUParallelDimsAttr::isSeq() const {
950 assert(!getArray().empty() && "no par_dims found");
951 if (getArray().size() == 1) {
952 auto parDim = dyn_cast<GPUParallelDimAttr>(getArray()[0]);
953 assert(parDim && "expected GPUParallelDimAttr");
954 return parDim.isSeq();
955 }
956 return false;
957}
958
959bool GPUParallelDimsAttr::isParallel() const { return !isSeq(); }
960
961bool GPUParallelDimsAttr::isMultiDim() const { return getArray().size() > 1; }
962
963bool GPUParallelDimsAttr::hasAnyBlockLevel() const {
964 return llvm::any_of(
965 getArray(), [](const GPUParallelDimAttr &p) { return p.isAnyBlock(); });
966}
967
968bool GPUParallelDimsAttr::hasOnlyBlockLevel() const {
969 return !getArray().empty() &&
970 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
971 return p.isAnyBlock();
972 });
973}
974
975bool GPUParallelDimsAttr::hasOnlyThreadYLevel() const {
976 return !getArray().empty() &&
977 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
978 return p.isThreadY();
979 });
980}
981
982bool GPUParallelDimsAttr::hasOnlyThreadXLevel() const {
983 return !getArray().empty() &&
984 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
985 return p.isThreadX();
986 });
987}
988
989Attribute GPUParallelDimsAttr::parse(AsmParser &parser, Type type) {
990 auto delimiter = AsmParser::Delimiter::Square;
992 auto parseParDim = [&]() -> ParseResult {
993 GPUParallelDimAttr dim;
994 if (parseProcessorValue(parser, dim))
995 return failure();
996 parDims.push_back(dim);
997 return success();
998 };
999 if (parser.parseCommaSeparatedList(delimiter, parseParDim,
1000 "list of OpenACC GPU parallel dimensions"))
1001 return {};
1002 return GPUParallelDimsAttr::get(parser.getContext(), parDims);
1003}
1004
1005void GPUParallelDimsAttr::print(AsmPrinter &printer) const {
1006 printer << "[";
1007 llvm::interleaveComma(getArray(), printer,
1008 [&printer](const GPUParallelDimAttr &p) {
1009 printProcessorValue(printer, p);
1010 });
1011 printer << "]";
1012}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1331
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1341
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then continue after the...
Definition OpenACC.cpp:525
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:536
b getContext())
static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op, GPUParallelDimAttr parDim)
static bool extractWaitClause(ComputeConstructT computeConstruct, DeviceType clauseDeviceType, MLIRContext *context, std::optional< Value > &waitDevnum, SmallVectorImpl< Value > &waitOperands, UnitAttr &waitOnly)
Extract wait for clauseDeviceType. Returns true if a clause was found.
static bool extractAsyncClause(ComputeConstructT computeConstruct, DeviceType clauseDeviceType, MLIRContext *context, std::optional< Value > &asyncOperand, UnitAttr &asyncOnly)
Extract async for clauseDeviceType. Returns true if a clause was found.
static void populateKernelEnvironmentAsyncWait(ComputeConstructT computeConstruct, DeviceType deviceType, std::optional< Value > &asyncOperand, UnitAttr &asyncOnly, std::optional< Value > &waitDevnum, SmallVectorImpl< Value > &waitOperands, UnitAttr &waitOnly)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:315
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
result_range getResults()
Definition Operation.h:440
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
bool isOperation() const
Return true if the successor is an operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5248
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5217
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5321
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5303
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5225
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.