MLIR 23.0.0git
OpenACCCG.cpp
Go to the documentation of this file.
1//===- OpenACCCG.cpp - OpenACC codegen ops, attributes, and types ---------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation for OpenACC codegen operations, attributes, and types.
10// These correspond to the definitions in OpenACCCG*.td tablegen files
11// and are kept in a separate file because they do not represent direct mappings
12// of OpenACC language constructs; they are intermediate representations used
13// when decomposing and lowering primary `acc` dialect operations.
14//
15//===----------------------------------------------------------------------===//
16
23#include "mlir/IR/Region.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallVector.h"
29
30using namespace mlir;
31using namespace acc;
32
33namespace {
34
35/// Generic helper for single-region OpenACC ops that execute their body once
36/// and then return to the parent operation with their results (if any).
37static void
41 if (point.isParent()) {
42 regions.push_back(RegionSuccessor(&region));
43 return;
44 }
45 regions.push_back(RegionSuccessor::parent());
46}
47
49 RegionSuccessor successor) {
50 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
51}
52
53/// Remove empty acc.kernel_environment operations. If the operation has wait
54/// operands, create a acc.wait operation to preserve synchronization.
55struct RemoveEmptyKernelEnvironment
56 : public OpRewritePattern<acc::KernelEnvironmentOp> {
57 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
58
59 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
60 PatternRewriter &rewriter) const override {
61 assert(op->getNumRegions() == 1 && "expected op to have one region");
62
63 Block &block = op.getRegion().front();
64 if (!block.empty())
65 return failure();
66
67 // Conservatively disable canonicalization of empty acc.kernel_environment
68 // operations if the wait operands in the kernel_environment cannot be fully
69 // represented by acc.wait operation.
70
71 // Disable canonicalization if device type is not the default
72 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
73 for (auto attr : deviceTypeAttr) {
74 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
75 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
76 return failure();
77 }
78 }
79 }
80
81 // Disable canonicalization if any wait segment has a devnum
82 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
83 for (auto attr : hasDevnumAttr) {
84 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
85 if (boolAttr.getValue())
86 return failure();
87 }
88 }
89 }
90
91 // Disable canonicalization if there are multiple wait segments
92 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
93 if (segmentsAttr.size() > 1)
94 return failure();
95 }
96
97 // Remove empty kernel environment.
98 // Preserve synchronization by creating acc.wait operation if needed.
99 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
100 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
101 /*asyncOperand=*/Value(),
102 /*waitDevnum=*/Value(),
103 /*async=*/nullptr,
104 /*ifCond=*/Value());
105 else
106 rewriter.eraseOp(op);
107
108 return success();
109 }
110};
111
112static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
113 PatternRewriter &rewriter,
114 size_t numInput) {
115 const size_t numLaunch = op.getLaunchArgs().size();
116 op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
117 rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunch),
118 static_cast<int32_t>(numInput),
119 op.getStream() ? 1 : 0}));
120}
121
122struct ComputeRegionRemoveDuplicateArgs
123 : public OpRewritePattern<ComputeRegionOp> {
125
126 LogicalResult matchAndRewrite(ComputeRegionOp op,
127 PatternRewriter &rewriter) const override {
128 Block *body = op.getBody();
129 const size_t numLaunch = op.getLaunchArgs().size();
130 size_t numInput = op.getInputArgs().size();
131 assert(body->getNumArguments() == numLaunch + numInput &&
132 "region args mismatch");
133
134 bool mergedAny = false;
135 while (true) {
136 bool merged = false;
137 for (size_t j = 1; j < numInput && !merged; ++j) {
138 for (size_t i = 0; i < j; ++i) {
139 if (op->getOperand(static_cast<unsigned>(numLaunch + i)) !=
140 op->getOperand(static_cast<unsigned>(numLaunch + j)))
141 continue;
142 unsigned keepIdx = static_cast<unsigned>(numLaunch + i);
143 unsigned dropIdx = static_cast<unsigned>(numLaunch + j);
144 rewriter.replaceAllUsesWith(body->getArgument(dropIdx),
145 body->getArgument(keepIdx));
146 body->eraseArgument(dropIdx);
147 op->eraseOperand(dropIdx);
148 --numInput;
149 merged = true;
150 mergedAny = true;
151 break;
152 }
153 }
154 if (!merged)
155 break;
156 }
157
158 if (!mergedAny)
159 return failure();
160 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
161 return success();
162 }
163};
164
165struct ComputeRegionRemoveUnusedArgs
166 : public OpRewritePattern<ComputeRegionOp> {
168
169 LogicalResult matchAndRewrite(ComputeRegionOp op,
170 PatternRewriter &rewriter) const override {
171 Block *body = op.getBody();
172 const size_t numLaunch = op.getLaunchArgs().size();
173 size_t numInput = op.getInputArgs().size();
174 assert(body->getNumArguments() == numLaunch + numInput &&
175 "region args mismatch");
176
177 bool changed = false;
178 for (size_t k = numLaunch; k < numLaunch + numInput;) {
179 if (!body->getArgument(static_cast<unsigned>(k)).use_empty()) {
180 ++k;
181 continue;
182 }
183 body->eraseArgument(static_cast<unsigned>(k));
184 op->eraseOperand(static_cast<unsigned>(k));
185 --numInput;
186 changed = true;
187 }
188
189 if (!changed)
190 return failure();
191 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
192 return success();
193 }
194};
195
196template <typename EffectTy>
197static void addOperandEffect(
199 &effects,
200 const MutableOperandRange &operand) {
201 for (unsigned i = 0, e = operand.size(); i < e; ++i)
202 effects.emplace_back(EffectTy::get(), &operand[i]);
203}
204
205template <typename EffectTy>
206static void addResultEffect(
208 &effects,
209 Value result) {
210 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
211}
212
213static int64_t gpuProcessorIndex(gpu::Processor p) {
214 switch (p) {
215 case gpu::Processor::Sequential:
216 return 0;
217 case gpu::Processor::ThreadX:
218 return 1;
219 case gpu::Processor::ThreadY:
220 return 2;
221 case gpu::Processor::ThreadZ:
222 return 3;
223 case gpu::Processor::BlockX:
224 return 4;
225 case gpu::Processor::BlockY:
226 return 5;
227 case gpu::Processor::BlockZ:
228 return 6;
229 }
230 llvm_unreachable("unhandled gpu::Processor");
231}
232
233static gpu::Processor indexToGpuProcessor(int64_t idx) {
234 switch (idx) {
235 case 0:
236 return gpu::Processor::Sequential;
237 case 1:
238 return gpu::Processor::ThreadX;
239 case 2:
240 return gpu::Processor::ThreadY;
241 case 3:
242 return gpu::Processor::ThreadZ;
243 case 4:
244 return gpu::Processor::BlockX;
245 case 5:
246 return gpu::Processor::BlockY;
247 case 6:
248 return gpu::Processor::BlockZ;
249 default:
250 return gpu::Processor::Sequential;
251 }
252}
253
254static GPUParallelDimAttr intToParDim(MLIRContext *context, int64_t dimInt) {
255 return GPUParallelDimAttr::get(
256 context, IntegerAttr::get(IndexType::get(context), dimInt));
257}
258
259static GPUParallelDimAttr processorParDim(MLIRContext *context,
260 gpu::Processor proc) {
261 return GPUParallelDimAttr::get(
262 context,
263 IntegerAttr::get(IndexType::get(context), gpuProcessorIndex(proc)));
264}
265
266static ParseResult parseProcessorValue(AsmParser &parser,
267 GPUParallelDimAttr &dim) {
268 std::string keyword;
269 llvm::SMLoc loc = parser.getCurrentLocation();
270 if (failed(parser.parseKeywordOrString(&keyword)))
271 return failure();
272 auto maybeProcessor = gpu::symbolizeProcessor(keyword);
273 if (!maybeProcessor)
274 return parser.emitError(loc)
275 << "expected one of ::mlir::gpu::Processor enum names";
276 dim = intToParDim(parser.getContext(), gpuProcessorIndex(*maybeProcessor));
277 return success();
278}
279
280static void printProcessorValue(AsmPrinter &printer,
281 const GPUParallelDimAttr &attr) {
282 gpu::Processor processor = indexToGpuProcessor(attr.getValue().getInt());
283 printer << gpu::stringifyProcessor(processor);
284}
285
286} // namespace
287
288//===----------------------------------------------------------------------===//
289// KernelEnvironmentOp
290//===----------------------------------------------------------------------===//
291
292void KernelEnvironmentOp::getSuccessorRegions(
294 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
295 regions);
296}
297
298ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
299 return getSingleRegionSuccessorInputs(getOperation(), successor);
300}
301
/// Registers canonicalization patterns for acc.kernel_environment: currently
/// only the removal of empty environments (see RemoveEmptyKernelEnvironment).
void KernelEnvironmentOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveEmptyKernelEnvironment>(context);
}
306
307template <typename ComputeConstructT>
308KernelEnvironmentOp
309KernelEnvironmentOp::createAndPopulate(ComputeConstructT computeConstruct,
310 OpBuilder &builder) {
311 auto kernelEnvironment = KernelEnvironmentOp::create(
312 builder, computeConstruct->getLoc(),
313 computeConstruct.getDataClauseOperands(),
314 computeConstruct.getAsyncOperands(),
315 computeConstruct.getAsyncOperandsDeviceTypeAttr(),
316 computeConstruct.getAsyncOnlyAttr(), computeConstruct.getWaitOperands(),
317 computeConstruct.getWaitOperandsSegmentsAttr(),
318 computeConstruct.getWaitOperandsDeviceTypeAttr(),
319 computeConstruct.getHasWaitDevnumAttr(),
320 computeConstruct.getWaitOnlyAttr());
321 Block &block = kernelEnvironment.getRegion().emplaceBlock();
322 builder.setInsertionPointToStart(&block);
323 return kernelEnvironment;
324}
325
// Explicit instantiations for the OpenACC compute constructs that decompose
// into an acc.kernel_environment.
template KernelEnvironmentOp
KernelEnvironmentOp::createAndPopulate<ParallelOp>(ParallelOp, OpBuilder &);
template KernelEnvironmentOp
KernelEnvironmentOp::createAndPopulate<KernelsOp>(KernelsOp, OpBuilder &);
template KernelEnvironmentOp
KernelEnvironmentOp::createAndPopulate<SerialOp>(SerialOp, OpBuilder &);
332
333//===----------------------------------------------------------------------===//
334// FirstprivateMapInitialOp
335//===----------------------------------------------------------------------===//
336
337LogicalResult FirstprivateMapInitialOp::verify() {
338 if (getDataClause() != acc::DataClause::acc_firstprivate)
339 return emitError("data clause associated with firstprivate operation must "
340 "match its intent");
341 if (!getVar())
342 return emitError("must have var operand");
343 if (!mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
344 !mlir::isa<mlir::acc::MappableType>(getVar().getType()))
345 return emitError("var must be mappable or pointer-like");
346 if (mlir::isa<mlir::acc::PointerLikeType>(getVar().getType()) &&
347 getVarType() == getVar().getType())
348 return emitError("varType must capture the element type of var");
349 if (getModifiers() != acc::DataClauseModifier::none)
350 return emitError("no data clause modifiers are allowed");
351 return success();
352}
353
// NOTE(review): the source listing dropped continuation lines of this
// definition (the `effects` parameter declaration and the trailing arguments
// of the first emplace_back call) — restore them from the original source
// before editing; the fragment below is incomplete as extracted.
void FirstprivateMapInitialOp::getEffects(
    &effects) {
  effects.emplace_back(MemoryEffects::Read::get(),
  addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
}
362
363//===----------------------------------------------------------------------===//
364// ReductionInitOp
365//===----------------------------------------------------------------------===//
366
367void ReductionInitOp::getSuccessorRegions(
369 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
370 regions);
371}
372
373void ReductionInitOp::getRegionInvocationBounds(
374 ArrayRef<Attribute> operands,
375 SmallVectorImpl<InvocationBounds> &invocationBounds) {
376 invocationBounds.emplace_back(1, 1);
377}
378
/// Values forwarded along `successor`: the op's results when branching back
/// to the parent.
ValueRange ReductionInitOp::getSuccessorInputs(RegionSuccessor successor) {
  return getSingleRegionSuccessorInputs(getOperation(), successor);
}
382
383LogicalResult ReductionInitOp::verify() {
384 Block &block = getRegion().front();
385 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
386 if (yieldOp.getNumOperands() != 1)
387 return emitOpError(
388 "region must yield exactly one value (private storage)");
389 if (yieldOp.getOperand(0).getType() != getVar().getType())
390 return emitOpError("yielded value type must match var type");
391 }
392 return success();
393}
394
395//===----------------------------------------------------------------------===//
396// ReductionCombineRegionOp
397//===----------------------------------------------------------------------===//
398
399void ReductionCombineRegionOp::getSuccessorRegions(
401 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
402 regions);
403}
404
405void ReductionCombineRegionOp::getRegionInvocationBounds(
406 ArrayRef<Attribute> operands,
407 SmallVectorImpl<InvocationBounds> &invocationBounds) {
408 invocationBounds.emplace_back(1, 1);
409}
410
412ReductionCombineRegionOp::getSuccessorInputs(RegionSuccessor successor) {
413 return getSingleRegionSuccessorInputs(getOperation(), successor);
414}
415
416LogicalResult ReductionCombineRegionOp::verify() {
417 Block &block = getRegion().front();
418 if (auto yieldOp = dyn_cast<acc::YieldOp>(block.getTerminator())) {
419 if (yieldOp.getNumOperands() != 0)
420 return emitOpError("region must be terminated by acc.yield with no "
421 "operands");
422 }
423 return success();
424}
425
426//===----------------------------------------------------------------------===//
427// ReductionCombineOp
428//===----------------------------------------------------------------------===//
429
// NOTE(review): the source listing dropped continuation lines of this
// definition (the `effects` parameter declaration and the trailing resource
// argument of each emplace_back call) — restore them from the original
// source before editing; the fragment below is incomplete as extracted.
void ReductionCombineOp::getEffects(
    &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemrefMutable(),
  effects.emplace_back(MemoryEffects::Read::get(), &getDestMemrefMutable(),
  effects.emplace_back(MemoryEffects::Write::get(), &getDestMemrefMutable(),
}
440
441//===----------------------------------------------------------------------===//
442// ComputeRegionOp
443//===----------------------------------------------------------------------===//
444
445static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op,
446 GPUParallelDimAttr parDim) {
447 for (auto launchArg : op.getLaunchArgs()) {
448 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
449 if (!parOp)
450 continue;
451 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
452 if (launchArgDim == parDim)
453 return parOp;
454 }
455 return nullptr;
456}
457
458std::optional<Value> ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
459 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
460 return parWidthOp.getResult();
461 return {};
462}
463
464std::optional<Value>
465ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
466 if (auto parWidthOp = getParWidthOpForLaunchArg(*this, parDim))
467 if (parWidthOp.getLaunchArg())
468 return parWidthOp.getLaunchArg();
469 return {};
470}
471
472std::optional<uint64_t>
473ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
474 auto knownParWidth = getKnownLaunchArg(parDim);
475 if (knownParWidth.has_value())
476 return getConstantIntValue(knownParWidth.value());
477 return {};
478}
479
480BlockArgument ComputeRegionOp::appendInputArg(Value value) {
481 getInputArgsMutable().append(value);
482 return getBody()->addArgument(value.getType(), getLoc());
483}
484
485std::optional<BlockArgument>
486ComputeRegionOp::wireHoistedValueThroughIns(Value value) {
487 Region &region = getRegion();
488
489 auto useIsInRegion = [&](OpOperand &use) -> bool {
490 return region.isAncestor(use.getOwner()->getParentRegion());
491 };
492
493 if (!areValuesDefinedAbove(ValueRange(value), region) ||
494 !llvm::any_of(value.getUses(), useIsInRegion))
495 return std::nullopt;
496
497 BlockArgument arg = appendInputArg(value);
498 replaceAllUsesInRegionWith(value, arg, region);
499 return arg;
500}
501
502bool ComputeRegionOp::isEffectivelySerial() {
503 auto *ctx = getContext();
504
505 if (getLaunchArg(GPUParallelDimAttr::seqDim(ctx)))
506 return true;
507
508 auto checkDim = [&](GPUParallelDimAttr dim) -> bool {
509 auto val = getKnownConstantLaunchArg(dim);
510 return val && *val == 1;
511 };
512
513 return checkDim(GPUParallelDimAttr::threadXDim(ctx)) &&
514 checkDim(GPUParallelDimAttr::threadYDim(ctx)) &&
515 checkDim(GPUParallelDimAttr::threadZDim(ctx)) &&
516 checkDim(GPUParallelDimAttr::blockXDim(ctx)) &&
517 checkDim(GPUParallelDimAttr::blockYDim(ctx)) &&
518 checkDim(GPUParallelDimAttr::blockZDim(ctx));
519}
520
521BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
522 for (auto [pos, launchArg] : llvm::enumerate(getLaunchArgs())) {
523 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
524 assert(parOp);
525 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
526 if (launchArgDim == parDim) {
527 assert(pos < getRegion().front().getNumArguments() &&
528 "launch arg position out of range");
529 return getRegion().front().getArgument(pos);
530 }
531 }
532 llvm_unreachable("attempting to get unspecified parDim");
533}
534
535SmallVector<GPUParallelDimAttr> ComputeRegionOp::getLaunchParDims() {
537 for (auto launchArg : getLaunchArgs()) {
538 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
539 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
540 int64_t dimInt = launchArgDim.getValue().getInt();
541 parDims.push_back(intToParDim(getContext(), dimInt));
542 }
543 return parDims;
544}
545
546Value ComputeRegionOp::getOperand(BlockArgument blockArg) {
547 Block *body = getBody();
548 if (blockArg.getOwner() != body)
549 return Value();
550 unsigned argNumber = blockArg.getArgNumber();
551 unsigned numLaunchArgs = getLaunchArgs().size();
552 unsigned numInputArgs = getInputArgs().size();
553 if (argNumber >= numLaunchArgs + numInputArgs)
554 return Value();
555 if (argNumber < numLaunchArgs)
556 return getLaunchArgs()[argNumber];
557 return getInputArgs()[argNumber - numLaunchArgs];
558}
559
560std::optional<BlockArgument> ComputeRegionOp::getBlockArg(Value value) {
561 Block *body = getBody();
562 for (auto [idx, launchVal] : llvm::enumerate(getLaunchArgs())) {
563 if (launchVal == value)
564 return body->getArgument(idx);
565 }
566 unsigned numLaunch = getLaunchArgs().size();
567 for (auto [idx, inputVal] : llvm::enumerate(getInputArgs())) {
568 if (inputVal == value)
569 return body->getArgument(numLaunch + idx);
570 }
571 return std::nullopt;
572}
573
/// Registers canonicalization patterns for acc.compute_region: duplicate and
/// unused `ins` operand elimination.
void ComputeRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<ComputeRegionRemoveDuplicateArgs, ComputeRegionRemoveUnusedArgs>(
      context);
}
579
/// Returns the width block argument for the given gpu::Processor level.
BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
  return parDimToWidth(GPUParallelDimAttr::get(getContext(), processor));
}
583
584LogicalResult ComputeRegionOp::verify() {
585 for (auto op : getLaunchArgs())
586 if (!op.getDefiningOp<acc::ParWidthOp>())
587 return emitOpError(
588 "launch arguments must be results of acc.par_width operations");
589
590 unsigned expectedBlockArgs = getLaunchArgs().size() + getInputArgs().size();
591 unsigned actualBlockArgs = getRegion().front().getNumArguments();
592 if (expectedBlockArgs != actualBlockArgs)
593 return emitOpError("expected ")
594 << expectedBlockArgs << " block arguments (launch + input), got "
595 << actualBlockArgs;
596
597 return success();
598}
599
/// Prints the custom assembly for acc.compute_region:
///   [` stream(` value ` : ` type `)`]
///   [` launch(` blockArg ` = ` operand, ... `)`]
///   [` ins(` blockArg ` = ` operand, ... `) : (` types `)`]
///   [`->` result-types] ` ` region attr-dict
/// Must mirror ComputeRegionOp::parse exactly.
void ComputeRegionOp::print(OpAsmPrinter &p) {
  ValueRange regionArgs = getBody()->getArguments();
  ValueRange launchArgs = getLaunchArgs();
  ValueRange inputArgs = getInputArgs();

  assert(regionArgs.size() == (launchArgs.size() + inputArgs.size()) &&
         "region args mismatch");

  if (getStream())
    p << " stream(" << getStream() << " : " << getStream().getType() << ")";

  // `i` walks the entry-block arguments across both clauses: launch
  // arguments first, then ins arguments.
  size_t i = 0;
  if (!launchArgs.empty()) {
    p << " launch(";
    for (size_t j = 0; j < launchArgs.size(); ++j, ++i) {
      p << regionArgs[i] << " = " << launchArgs[j];
      if (j < launchArgs.size() - 1)
        p << ", ";
    }
    p << ")";
  }
  if (!inputArgs.empty()) {
    p << " ins(";
    for (size_t j = 0; j < inputArgs.size(); ++j, ++i) {
      p << regionArgs[i] << " = " << inputArgs[j];
      if (j < inputArgs.size() - 1)
        p << ", ";
    }
    // The ins types are printed in a trailing parenthesized list, matching
    // the grammar accepted by parse().
    p << ") : (";
    for (size_t j = 0; j < inputArgs.size(); ++j) {
      p << inputArgs[j].getType();
      if (j < inputArgs.size() - 1)
        p << ", ";
    }
    p << ")";
  }
  p.printOptionalArrowTypeList(getResultTypes());
  p << " ";
  // Entry-block args were already printed in the launch/ins clauses above.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getOperandSegmentSizeAttr());
}
642
643ParseResult ComputeRegionOp::parse(OpAsmParser &parser,
645 auto &builder = parser.getBuilder();
646
648 OpAsmParser::UnresolvedOperand streamOperand;
649 Type streamType;
652 SmallVector<Type> types;
653
654 bool hasStream = false;
655 if (succeeded(parser.parseOptionalKeyword("stream"))) {
656 hasStream = true;
657 if (parser.parseLParen() || parser.parseOperand(streamOperand) ||
658 parser.parseColon() || parser.parseType(streamType) ||
659 parser.parseRParen())
660 return failure();
661 }
662
663 if (succeeded(parser.parseOptionalKeyword("launch"))) {
664 if (parser.parseAssignmentList(regionArgs, launchOperands))
665 return failure();
666 Type indexType = builder.getIndexType();
667 for (size_t i = 0; i < regionArgs.size(); ++i)
668 types.push_back(indexType);
669 }
670
671 if (succeeded(parser.parseOptionalKeyword("ins"))) {
672 if (parser.parseAssignmentList(regionArgs, inputOperands) ||
673 parser.parseColon() || parser.parseLParen() ||
674 parser.parseTypeList(types) || parser.parseRParen())
675 return failure();
676 }
677
678 if (parser.parseOptionalArrowTypeList(result.types))
679 return failure();
680
681 for (auto [iterArg, type] : llvm::zip_equal(regionArgs, types))
682 iterArg.type = type;
683
684 Region *body = result.addRegion();
685 if (parser.parseRegion(*body, regionArgs))
686 return failure();
687
688 const size_t numLaunchOperands = launchOperands.size();
689 const size_t numInputOperands = inputOperands.size();
690 assert(numLaunchOperands + numInputOperands == regionArgs.size() &&
691 "compute region args mismatch");
692
693 result.addAttribute(
694 ComputeRegionOp::getOperandSegmentSizeAttr(),
695 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLaunchOperands),
696 static_cast<int32_t>(numInputOperands),
697 hasStream ? 1 : 0}));
698
699 for (size_t i = 0; i < numLaunchOperands; ++i) {
700 if (parser.resolveOperand(launchOperands[i], types[i], result.operands))
701 return failure();
702 }
703
704 for (size_t i = numLaunchOperands; i < regionArgs.size(); ++i) {
705 if (parser.resolveOperand(inputOperands[i - numLaunchOperands], types[i],
706 result.operands))
707 return failure();
708 }
709
710 if (hasStream) {
711 if (parser.resolveOperand(streamOperand, streamType, result.operands))
712 return failure();
713 }
714
715 if (parser.parseOptionalAttrDict(result.attributes))
716 return failure();
717
718 return success();
719}
720
721//===----------------------------------------------------------------------===//
722// GPUParallelDimAttr
723//===----------------------------------------------------------------------===//
724
// Named constructors: each builds the GPUParallelDimAttr for one specific
// gpu::Processor level via the shared processorParDim helper.

GPUParallelDimAttr GPUParallelDimAttr::get(MLIRContext *context,
                                           gpu::Processor proc) {
  return processorParDim(context, proc);
}

GPUParallelDimAttr GPUParallelDimAttr::seqDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::Sequential);
}

GPUParallelDimAttr GPUParallelDimAttr::threadXDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadX);
}

GPUParallelDimAttr GPUParallelDimAttr::threadYDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadY);
}

GPUParallelDimAttr GPUParallelDimAttr::threadZDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::ThreadZ);
}

GPUParallelDimAttr GPUParallelDimAttr::blockXDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockX);
}

GPUParallelDimAttr GPUParallelDimAttr::blockYDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockY);
}

GPUParallelDimAttr GPUParallelDimAttr::blockZDim(MLIRContext *context) {
  return processorParDim(context, gpu::Processor::BlockZ);
}
757
758Attribute GPUParallelDimAttr::parse(AsmParser &parser, Type type) {
759 GPUParallelDimAttr dim;
760 if (parser.parseLess() || parseProcessorValue(parser, dim) ||
761 parser.parseGreater()) {
762 parser.emitError(parser.getCurrentLocation(),
763 "expected format `<` processor_name `>`");
764 return {};
765 }
766 return dim;
767}
768
/// Prints the attribute in the form `<` processor-name `>`.
void GPUParallelDimAttr::print(AsmPrinter &printer) const {
  printer << "<";
  printProcessorValue(printer, *this);
  printer << ">";
}
774
775GPUParallelDimAttr GPUParallelDimAttr::threadDim(MLIRContext *context,
776 unsigned index) {
777 assert(index <= 2 && "thread dimension index must be 0, 1, or 2");
778 switch (index) {
779 case 0:
780 return threadXDim(context);
781 case 1:
782 return threadYDim(context);
783 case 2:
784 return threadZDim(context);
785 }
786 llvm_unreachable("validated thread dimension index");
787}
788
789GPUParallelDimAttr GPUParallelDimAttr::blockDim(MLIRContext *context,
790 unsigned index) {
791 assert(index <= 2 && "block dimension index must be 0, 1, or 2");
792 switch (index) {
793 case 0:
794 return blockXDim(context);
795 case 1:
796 return blockYDim(context);
797 case 2:
798 return blockZDim(context);
799 }
800 llvm_unreachable("validated block dimension index");
801}
802
/// Decodes the stored integer back into the gpu::Processor it represents.
gpu::Processor GPUParallelDimAttr::getProcessor() const {
  return indexToGpuProcessor(getValue().getInt());
}

/// Total ordering over dimensions: Sequential (0) < ThreadX/Y/Z (1-3) <
/// BlockX/Y/Z (4-6).
int GPUParallelDimAttr::getOrder() const {
  return gpuProcessorIndex(getProcessor());
}
810
811GPUParallelDimAttr GPUParallelDimAttr::getOneHigher() const {
812 int order = getOrder();
813 if (order >= 6) // BlockZ is the highest
814 return *this;
815 return get(getContext(), indexToGpuProcessor(order + 1));
816}
817
818GPUParallelDimAttr GPUParallelDimAttr::getOneLower() const {
819 int order = getOrder();
820 if (order <= 0) // Sequential is the lowest
821 return *this;
822 return get(getContext(), indexToGpuProcessor(order - 1));
823}
824
// Predicates testing which processor level this dimension encodes.
bool GPUParallelDimAttr::isSeq() const {
  return getProcessor() == gpu::Processor::Sequential;
}
bool GPUParallelDimAttr::isThreadX() const {
  return getProcessor() == gpu::Processor::ThreadX;
}
bool GPUParallelDimAttr::isThreadY() const {
  return getProcessor() == gpu::Processor::ThreadY;
}
bool GPUParallelDimAttr::isThreadZ() const {
  return getProcessor() == gpu::Processor::ThreadZ;
}
bool GPUParallelDimAttr::isBlockX() const {
  return getProcessor() == gpu::Processor::BlockX;
}
bool GPUParallelDimAttr::isBlockY() const {
  return getProcessor() == gpu::Processor::BlockY;
}
bool GPUParallelDimAttr::isBlockZ() const {
  return getProcessor() == gpu::Processor::BlockZ;
}
// True for any of the thread levels (x, y, or z).
bool GPUParallelDimAttr::isAnyThread() const {
  return isThreadX() || isThreadY() || isThreadZ();
}
// True for any of the block levels (x, y, or z).
bool GPUParallelDimAttr::isAnyBlock() const {
  return isBlockX() || isBlockY() || isBlockZ();
}
852
853//===----------------------------------------------------------------------===//
854// GPUParallelDimsAttr
855//===----------------------------------------------------------------------===//
856
/// Convenience constructor: the singleton list holding only the sequential
/// dimension.
GPUParallelDimsAttr GPUParallelDimsAttr::seq(MLIRContext *ctx) {
  return GPUParallelDimsAttr::get(ctx, {GPUParallelDimAttr::seqDim(ctx)});
}
860
861bool GPUParallelDimsAttr::isSeq() const {
862 assert(!getArray().empty() && "no par_dims found");
863 if (getArray().size() == 1) {
864 auto parDim = dyn_cast<GPUParallelDimAttr>(getArray()[0]);
865 assert(parDim && "expected GPUParallelDimAttr");
866 return parDim.isSeq();
867 }
868 return false;
869}
870
// A dims list is parallel iff it is not the singleton [seq].
bool GPUParallelDimsAttr::isParallel() const { return !isSeq(); }

// True when more than one parallel dimension is attached.
bool GPUParallelDimsAttr::isMultiDim() const { return getArray().size() > 1; }
874
875bool GPUParallelDimsAttr::hasAnyBlockLevel() const {
876 return llvm::any_of(
877 getArray(), [](const GPUParallelDimAttr &p) { return p.isAnyBlock(); });
878}
879
880bool GPUParallelDimsAttr::hasOnlyBlockLevel() const {
881 return !getArray().empty() &&
882 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
883 return p.isAnyBlock();
884 });
885}
886
887bool GPUParallelDimsAttr::hasOnlyThreadYLevel() const {
888 return !getArray().empty() &&
889 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
890 return p.isThreadY();
891 });
892}
893
894bool GPUParallelDimsAttr::hasOnlyThreadXLevel() const {
895 return !getArray().empty() &&
896 llvm::all_of(getArray(), [](const GPUParallelDimAttr &p) {
897 return p.isThreadX();
898 });
899}
900
901Attribute GPUParallelDimsAttr::parse(AsmParser &parser, Type type) {
902 auto delimiter = AsmParser::Delimiter::Square;
904 auto parseParDim = [&]() -> ParseResult {
905 GPUParallelDimAttr dim;
906 if (parseProcessorValue(parser, dim))
907 return failure();
908 parDims.push_back(dim);
909 return success();
910 };
911 if (parser.parseCommaSeparatedList(delimiter, parseParDim,
912 "list of OpenACC GPU parallel dimensions"))
913 return {};
914 return GPUParallelDimsAttr::get(parser.getContext(), parDims);
915}
916
917void GPUParallelDimsAttr::print(AsmPrinter &printer) const {
918 printer << "[";
919 llvm::interleaveComma(getArray(), printer,
920 [&printer](const GPUParallelDimAttr &p) {
921 printProcessorValue(printer, p);
922 });
923 printer << "]";
924}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1295
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1305
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:493
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:504
b getContext())
static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op, GPUParallelDimAttr parDim)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:315
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IndexType getIndexType()
Definition Builders.cpp:55
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:441
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5174
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5143
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5247
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5151
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
Definition RegionUtils.h:26
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.