27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallVector.h"
55struct RemoveEmptyKernelEnvironment
57 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
59 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
60 PatternRewriter &rewriter)
const override {
61 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
63 Block &block = op.getRegion().front();
69 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
71 op, op.getWaitOperands(), Value(),
72 op.getWaitDevnum(),
nullptr, Value());
80static void updateComputeRegionInputOperandSegments(ComputeRegionOp op,
83 const size_t numLaunch = op.getLaunchArgs().size();
84 op->setAttr(ComputeRegionOp::getOperandSegmentSizeAttr(),
86 static_cast<int32_t>(numInput),
87 op.getStream() ? 1 : 0}));
90struct ComputeRegionRemoveDuplicateArgs
94 LogicalResult matchAndRewrite(ComputeRegionOp op,
95 PatternRewriter &rewriter)
const override {
96 Block *body = op.getBody();
97 const size_t numLaunch = op.getLaunchArgs().size();
98 size_t numInput = op.getInputArgs().size();
100 "region args mismatch");
102 bool mergedAny =
false;
105 for (
size_t j = 1; j < numInput && !merged; ++j) {
106 for (
size_t i = 0; i < j; ++i) {
107 if (op->getOperand(
static_cast<unsigned>(numLaunch + i)) !=
108 op->getOperand(
static_cast<unsigned>(numLaunch + j)))
110 unsigned keepIdx =
static_cast<unsigned>(numLaunch + i);
111 unsigned dropIdx =
static_cast<unsigned>(numLaunch + j);
115 op->eraseOperand(dropIdx);
128 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
133struct ComputeRegionRemoveUnusedArgs
137 LogicalResult matchAndRewrite(ComputeRegionOp op,
138 PatternRewriter &rewriter)
const override {
139 Block *body = op.getBody();
140 const size_t numLaunch = op.getLaunchArgs().size();
141 size_t numInput = op.getInputArgs().size();
143 "region args mismatch");
145 bool changed =
false;
146 for (
size_t k = numLaunch; k < numLaunch + numInput;) {
152 op->eraseOperand(
static_cast<unsigned>(k));
159 updateComputeRegionInputOperandSegments(op, rewriter, numInput);
164template <
typename EffectTy>
169 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
170 effects.emplace_back(EffectTy::get(), &operand[i]);
173template <
typename EffectTy>
178 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
181static int64_t gpuProcessorIndex(gpu::Processor p) {
183 case gpu::Processor::Sequential:
185 case gpu::Processor::ThreadX:
187 case gpu::Processor::ThreadY:
189 case gpu::Processor::ThreadZ:
191 case gpu::Processor::BlockX:
193 case gpu::Processor::BlockY:
195 case gpu::Processor::BlockZ:
198 llvm_unreachable(
"unhandled gpu::Processor");
201static gpu::Processor indexToGpuProcessor(
int64_t idx) {
204 return gpu::Processor::Sequential;
206 return gpu::Processor::ThreadX;
208 return gpu::Processor::ThreadY;
210 return gpu::Processor::ThreadZ;
212 return gpu::Processor::BlockX;
214 return gpu::Processor::BlockY;
216 return gpu::Processor::BlockZ;
218 return gpu::Processor::Sequential;
223 return GPUParallelDimAttr::get(
224 context, IntegerAttr::get(IndexType::get(context), dimInt));
227static GPUParallelDimAttr processorParDim(
MLIRContext *context,
228 gpu::Processor proc) {
229 return GPUParallelDimAttr::get(
231 IntegerAttr::get(IndexType::get(context), gpuProcessorIndex(proc)));
234static ParseResult parseProcessorValue(
AsmParser &parser,
235 GPUParallelDimAttr &dim) {
240 auto maybeProcessor = gpu::symbolizeProcessor(keyword);
243 <<
"expected one of ::mlir::gpu::Processor enum names";
244 dim = intToParDim(parser.
getContext(), gpuProcessorIndex(*maybeProcessor));
248static void printProcessorValue(
AsmPrinter &printer,
249 const GPUParallelDimAttr &attr) {
250 gpu::Processor processor = indexToGpuProcessor(attr.getValue().getInt());
251 printer << gpu::stringifyProcessor(processor);
260void KernelEnvironmentOp::getSuccessorRegions(
270void KernelEnvironmentOp::getCanonicalizationPatterns(
272 results.
add<RemoveEmptyKernelEnvironment>(context);
276template <
typename ComputeConstructT>
280 std::optional<Value> &asyncOperand, UnitAttr &asyncOnly) {
281 if (computeConstruct.hasAsyncOnly(clauseDeviceType)) {
282 asyncOnly = UnitAttr::get(context);
285 if (
Value asyncValue = computeConstruct.getAsyncValue(clauseDeviceType)) {
286 asyncOperand = asyncValue;
293template <
typename ComputeConstructT>
296 std::optional<Value> &waitDevnum,
298 UnitAttr &waitOnly) {
299 if (computeConstruct.hasWaitOnly(clauseDeviceType)) {
300 waitOnly = UnitAttr::get(context);
303 Value devnum = computeConstruct.getWaitDevnum(clauseDeviceType);
304 auto waitValues = computeConstruct.getWaitValues(clauseDeviceType);
305 if (!devnum && waitValues.empty())
309 waitOperands.append(waitValues.begin(), waitValues.end());
313template <
typename ComputeConstructT>
315 ComputeConstructT computeConstruct, DeviceType deviceType,
316 std::optional<Value> &asyncOperand, UnitAttr &asyncOnly,
318 UnitAttr &waitOnly) {
319 MLIRContext *context = computeConstruct->getContext();
324 if (deviceType != DeviceType::None)
326 asyncOperand, asyncOnly);
330 waitOperands, waitOnly)) {
331 if (deviceType != DeviceType::None)
333 waitOperands, waitOnly);
337template <
typename ComputeConstructT>
339KernelEnvironmentOp::createAndPopulate(ComputeConstructT computeConstruct,
340 DeviceType deviceType,
342 std::optional<Value> asyncOperand;
343 UnitAttr asyncOnly =
nullptr;
344 std::optional<Value> waitDevnum;
346 UnitAttr waitOnly =
nullptr;
348 asyncOnly, waitDevnum, waitOperands,
351 auto kernelEnvironment = KernelEnvironmentOp::create(
352 builder, computeConstruct->getLoc(),
353 computeConstruct.getDataClauseOperands(), asyncOperand.value_or(
Value()),
354 asyncOnly, waitDevnum.value_or(
Value()), waitOperands, waitOnly);
355 Block &block = kernelEnvironment.getRegion().emplaceBlock();
357 return kernelEnvironment;
360template KernelEnvironmentOp
361KernelEnvironmentOp::createAndPopulate<ParallelOp>(ParallelOp, DeviceType,
363template KernelEnvironmentOp
364KernelEnvironmentOp::createAndPopulate<KernelsOp>(KernelsOp, DeviceType,
366template KernelEnvironmentOp
367KernelEnvironmentOp::createAndPopulate<SerialOp>(SerialOp, DeviceType,
370LogicalResult KernelEnvironmentOp::verify() {
372 return emitError(
"async-only cannot appear with async operand");
373 if (getWaitOnly() && (!getWaitOperands().empty() || getWaitDevnum()))
374 return emitError(
"wait-only cannot appear with wait operands or devnum");
382LogicalResult FirstprivateMapInitialOp::verify() {
384 return emitError(
"data clause associated with firstprivate operation must "
387 return emitError(
"must have var operand");
388 if (!mlir::isa<mlir::acc::PointerLikeType>(
getVar().
getType()) &&
390 return emitError(
"var must be mappable or pointer-like");
391 if (mlir::isa<mlir::acc::PointerLikeType>(
getVar().
getType()) &&
393 return emitError(
"varType must capture the element type of var");
394 if (getModifiers() != acc::DataClauseModifier::none)
395 return emitError(
"no data clause modifiers are allowed");
399void FirstprivateMapInitialOp::getEffects(
412void ReductionInitOp::getSuccessorRegions(
418void ReductionInitOp::getRegionInvocationBounds(
421 invocationBounds.emplace_back(1, 1);
428LogicalResult ReductionInitOp::verify() {
430 if (
auto yieldOp = dyn_cast<acc::YieldOp>(block.
getTerminator())) {
431 if (yieldOp.getNumOperands() != 1)
433 "region must yield exactly one value (private storage)");
435 return emitOpError(
"yielded value type must match var type");
444void ReductionCombineRegionOp::getSuccessorRegions(
450void ReductionCombineRegionOp::getRegionInvocationBounds(
453 invocationBounds.emplace_back(1, 1);
457ReductionCombineRegionOp::getSuccessorInputs(
RegionSuccessor successor) {
461LogicalResult ReductionCombineRegionOp::verify() {
463 if (
auto yieldOp = dyn_cast<acc::YieldOp>(block.
getTerminator())) {
464 if (yieldOp.getNumOperands() != 0)
465 return emitOpError(
"region must be terminated by acc.yield with no "
475LogicalResult ReductionAccumulateOp::verify() {
476 Type valueType = getValue().getType();
477 auto ptrLikeTy = cast<PointerLikeType>(getMemref().
getType());
478 Type elementType = ptrLikeTy.getElementType();
480 return emitOpError(
"pointer-like destination must have an element type");
481 if (elementType != valueType)
482 return emitOpError(
"pointer-like element type must match value type");
483 if (getParDims().getArray().empty())
484 return emitOpError(
"par_dims must specify at least one parallel dimension");
492LogicalResult ReductionAccumulateArrayOp::verify() {
493 if (getParDims().getArray().empty())
494 return emitOpError(
"par_dims must specify at least one parallel dimension");
502void ReductionCombineOp::getEffects(
518 GPUParallelDimAttr parDim) {
519 for (
auto launchArg : op.getLaunchArgs()) {
520 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
523 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
524 if (launchArgDim == parDim)
530std::optional<Value> ComputeRegionOp::getLaunchArg(GPUParallelDimAttr parDim) {
532 return parWidthOp.getResult();
537ComputeRegionOp::getKnownLaunchArg(GPUParallelDimAttr parDim) {
539 if (parWidthOp.getLaunchArg())
540 return parWidthOp.getLaunchArg();
544std::optional<uint64_t>
545ComputeRegionOp::getKnownConstantLaunchArg(GPUParallelDimAttr parDim) {
546 auto knownParWidth = getKnownLaunchArg(parDim);
547 if (knownParWidth.has_value())
553 getInputArgsMutable().append(value);
554 return getBody()->addArgument(value.
getType(), getLoc());
557std::optional<BlockArgument>
558ComputeRegionOp::wireHoistedValueThroughIns(
Value value) {
559 Region ®ion = getRegion();
561 auto useIsInRegion = [&](
OpOperand &use) ->
bool {
562 return region.
isAncestor(use.getOwner()->getParentRegion());
566 !llvm::any_of(value.
getUses(), useIsInRegion))
574bool ComputeRegionOp::isEffectivelySerial() {
577 if (getLaunchArg(GPUParallelDimAttr::seqDim(ctx)))
580 auto checkDim = [&](GPUParallelDimAttr dim) ->
bool {
581 auto val = getKnownConstantLaunchArg(dim);
582 return val && *val == 1;
585 return checkDim(GPUParallelDimAttr::threadXDim(ctx)) &&
586 checkDim(GPUParallelDimAttr::threadYDim(ctx)) &&
587 checkDim(GPUParallelDimAttr::threadZDim(ctx)) &&
588 checkDim(GPUParallelDimAttr::blockXDim(ctx)) &&
589 checkDim(GPUParallelDimAttr::blockYDim(ctx)) &&
590 checkDim(GPUParallelDimAttr::blockZDim(ctx));
593BlockArgument ComputeRegionOp::parDimToWidth(GPUParallelDimAttr parDim) {
594 for (
auto [pos, launchArg] : llvm::enumerate(getLaunchArgs())) {
595 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
597 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
598 if (launchArgDim == parDim) {
599 assert(pos < getRegion().front().getNumArguments() &&
600 "launch arg position out of range");
601 return getRegion().front().getArgument(pos);
604 llvm_unreachable(
"attempting to get unspecified parDim");
609 for (
auto launchArg : getLaunchArgs()) {
610 auto parOp = launchArg.getDefiningOp<ParWidthOp>();
611 auto launchArgDim = cast<GPUParallelDimAttr>(parOp.getParDim());
612 int64_t dimInt = launchArgDim.getValue().getInt();
613 parDims.push_back(intToParDim(
getContext(), dimInt));
619 Block *body = getBody();
623 unsigned numLaunchArgs = getLaunchArgs().size();
624 unsigned numInputArgs = getInputArgs().size();
625 if (argNumber >= numLaunchArgs + numInputArgs)
627 if (argNumber < numLaunchArgs)
628 return getLaunchArgs()[argNumber];
629 return getInputArgs()[argNumber - numLaunchArgs];
632std::optional<BlockArgument> ComputeRegionOp::getBlockArg(
Value value) {
633 Block *body = getBody();
634 for (
auto [idx, launchVal] : llvm::enumerate(getLaunchArgs())) {
635 if (launchVal == value)
638 unsigned numLaunch = getLaunchArgs().size();
639 for (
auto [idx, inputVal] : llvm::enumerate(getInputArgs())) {
640 if (inputVal == value)
648 results.
add<ComputeRegionRemoveDuplicateArgs, ComputeRegionRemoveUnusedArgs>(
652BlockArgument ComputeRegionOp::gpuParWidth(gpu::Processor processor) {
653 return parDimToWidth(GPUParallelDimAttr::get(
getContext(), processor));
656LogicalResult ComputeRegionOp::verify() {
657 for (
auto op : getLaunchArgs())
658 if (!op.getDefiningOp<acc::ParWidthOp>())
660 "launch arguments must be results of acc.par_width operations");
662 unsigned expectedBlockArgs = getLaunchArgs().size() + getInputArgs().size();
663 unsigned actualBlockArgs = getRegion().front().getNumArguments();
664 if (expectedBlockArgs != actualBlockArgs)
666 << expectedBlockArgs <<
" block arguments (launch + input), got "
673 ValueRange regionArgs = getBody()->getArguments();
677 assert(regionArgs.size() == (launchArgs.size() + inputArgs.size()) &&
678 "region args mismatch");
681 p <<
" stream(" << getStream() <<
" : " << getStream().getType() <<
")";
684 if (!launchArgs.empty()) {
686 for (
size_t j = 0;
j < launchArgs.size(); ++
j, ++i) {
687 p << regionArgs[i] <<
" = " << launchArgs[
j];
688 if (
j < launchArgs.size() - 1)
693 if (!inputArgs.empty()) {
695 for (
size_t j = 0;
j < inputArgs.size(); ++
j, ++i) {
696 p << regionArgs[i] <<
" = " << inputArgs[
j];
697 if (
j < inputArgs.size() - 1)
701 for (
size_t j = 0;
j < inputArgs.size(); ++
j) {
702 p << inputArgs[
j].getType();
703 if (
j < inputArgs.size() - 1)
712 getOperandSegmentSizeAttr());
715ParseResult ComputeRegionOp::parse(
OpAsmParser &parser,
726 bool hasStream =
false;
739 for (
size_t i = 0; i < regionArgs.size(); ++i)
740 types.push_back(indexType);
753 for (
auto [iterArg, type] : llvm::zip_equal(regionArgs, types))
759 ComputeRegionOp::ensureTerminator(*body, parser.
getBuilder(),
762 const size_t numLaunchOperands = launchOperands.size();
763 const size_t numInputOperands = inputOperands.size();
764 assert(numLaunchOperands + numInputOperands == regionArgs.size() &&
765 "compute region args mismatch");
768 ComputeRegionOp::getOperandSegmentSizeAttr(),
770 static_cast<int32_t>(numInputOperands),
771 hasStream ? 1 : 0}));
773 for (
size_t i = 0; i < numLaunchOperands; ++i) {
778 for (
size_t i = numLaunchOperands; i < regionArgs.size(); ++i) {
779 if (parser.
resolveOperand(inputOperands[i - numLaunchOperands], types[i],
799LogicalResult PredicateRegionOp::verify() {
800 if (getRegion().empty())
801 return emitOpError(
"region needs to have at least one block");
802 if (getRegion().front().getNumArguments() > 0)
803 return emitOpError(
"region cannot have any arguments");
804 if (!getOperation()->getParentOfType<ComputeRegionOp>())
805 return emitOpError(
"must be nested within an acc.compute_region operation");
813GPUParallelDimAttr GPUParallelDimAttr::get(
MLIRContext *context,
814 gpu::Processor proc) {
815 return processorParDim(context, proc);
818GPUParallelDimAttr GPUParallelDimAttr::seqDim(
MLIRContext *context) {
819 return processorParDim(context, gpu::Processor::Sequential);
822GPUParallelDimAttr GPUParallelDimAttr::threadXDim(
MLIRContext *context) {
823 return processorParDim(context, gpu::Processor::ThreadX);
826GPUParallelDimAttr GPUParallelDimAttr::threadYDim(
MLIRContext *context) {
827 return processorParDim(context, gpu::Processor::ThreadY);
830GPUParallelDimAttr GPUParallelDimAttr::threadZDim(
MLIRContext *context) {
831 return processorParDim(context, gpu::Processor::ThreadZ);
834GPUParallelDimAttr GPUParallelDimAttr::blockXDim(
MLIRContext *context) {
835 return processorParDim(context, gpu::Processor::BlockX);
838GPUParallelDimAttr GPUParallelDimAttr::blockYDim(
MLIRContext *context) {
839 return processorParDim(context, gpu::Processor::BlockY);
842GPUParallelDimAttr GPUParallelDimAttr::blockZDim(
MLIRContext *context) {
843 return processorParDim(context, gpu::Processor::BlockZ);
847 GPUParallelDimAttr dim;
848 if (parser.
parseLess() || parseProcessorValue(parser, dim) ||
851 "expected format `<` processor_name `>`");
857void GPUParallelDimAttr::print(
AsmPrinter &printer)
const {
859 printProcessorValue(printer, *
this);
863GPUParallelDimAttr GPUParallelDimAttr::threadDim(
MLIRContext *context,
865 assert(
index <= 2 &&
"thread dimension index must be 0, 1, or 2");
868 return threadXDim(context);
870 return threadYDim(context);
872 return threadZDim(context);
874 llvm_unreachable(
"validated thread dimension index");
877GPUParallelDimAttr GPUParallelDimAttr::blockDim(
MLIRContext *context,
879 assert(
index <= 2 &&
"block dimension index must be 0, 1, or 2");
882 return blockXDim(context);
884 return blockYDim(context);
886 return blockZDim(context);
888 llvm_unreachable(
"validated block dimension index");
891gpu::Processor GPUParallelDimAttr::getProcessor()
const {
892 return indexToGpuProcessor(getValue().getInt());
895int GPUParallelDimAttr::getOrder()
const {
896 return gpuProcessorIndex(getProcessor());
899GPUParallelDimAttr GPUParallelDimAttr::getOneHigher()
const {
900 int order = getOrder();
906GPUParallelDimAttr GPUParallelDimAttr::getOneLower()
const {
907 int order = getOrder();
913bool GPUParallelDimAttr::isSeq()
const {
914 return getProcessor() == gpu::Processor::Sequential;
916bool GPUParallelDimAttr::isThreadX()
const {
917 return getProcessor() == gpu::Processor::ThreadX;
919bool GPUParallelDimAttr::isThreadY()
const {
920 return getProcessor() == gpu::Processor::ThreadY;
922bool GPUParallelDimAttr::isThreadZ()
const {
923 return getProcessor() == gpu::Processor::ThreadZ;
925bool GPUParallelDimAttr::isBlockX()
const {
926 return getProcessor() == gpu::Processor::BlockX;
928bool GPUParallelDimAttr::isBlockY()
const {
929 return getProcessor() == gpu::Processor::BlockY;
931bool GPUParallelDimAttr::isBlockZ()
const {
932 return getProcessor() == gpu::Processor::BlockZ;
934bool GPUParallelDimAttr::isAnyThread()
const {
935 return isThreadX() || isThreadY() || isThreadZ();
937bool GPUParallelDimAttr::isAnyBlock()
const {
938 return isBlockX() || isBlockY() || isBlockZ();
945GPUParallelDimsAttr GPUParallelDimsAttr::seq(
MLIRContext *ctx) {
946 return GPUParallelDimsAttr::get(ctx, {GPUParallelDimAttr::seqDim(ctx)});
949bool GPUParallelDimsAttr::isSeq()
const {
950 assert(!getArray().empty() &&
"no par_dims found");
951 if (getArray().size() == 1) {
952 auto parDim = dyn_cast<GPUParallelDimAttr>(getArray()[0]);
953 assert(parDim &&
"expected GPUParallelDimAttr");
954 return parDim.isSeq();
/// Returns true when this list of parallel dimensions is not sequential,
/// i.e. the logical negation of isSeq().
959bool GPUParallelDimsAttr::isParallel()
const {
return !isSeq(); }
/// Returns true when the underlying array holds more than one entry.
961bool GPUParallelDimsAttr::isMultiDim()
const {
return getArray().size() > 1; }
963bool GPUParallelDimsAttr::hasAnyBlockLevel()
const {
965 getArray(), [](
const GPUParallelDimAttr &p) {
return p.isAnyBlock(); });
968bool GPUParallelDimsAttr::hasOnlyBlockLevel()
const {
969 return !getArray().empty() &&
970 llvm::all_of(getArray(), [](
const GPUParallelDimAttr &p) {
971 return p.isAnyBlock();
975bool GPUParallelDimsAttr::hasOnlyThreadYLevel()
const {
976 return !getArray().empty() &&
977 llvm::all_of(getArray(), [](
const GPUParallelDimAttr &p) {
978 return p.isThreadY();
982bool GPUParallelDimsAttr::hasOnlyThreadXLevel()
const {
983 return !getArray().empty() &&
984 llvm::all_of(getArray(), [](
const GPUParallelDimAttr &p) {
985 return p.isThreadX();
992 auto parseParDim = [&]() -> ParseResult {
993 GPUParallelDimAttr dim;
994 if (parseProcessorValue(parser, dim))
996 parDims.push_back(dim);
1000 "list of OpenACC GPU parallel dimensions"))
1002 return GPUParallelDimsAttr::get(parser.
getContext(), parDims);
1005void GPUParallelDimsAttr::print(
AsmPrinter &printer)
const {
1007 llvm::interleaveComma(getArray(), printer,
1008 [&printer](
const GPUParallelDimAttr &p) {
1009 printProcessorValue(printer, p);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then continue after the...
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParWidthOp getParWidthOpForLaunchArg(ComputeRegionOp op, GPUParallelDimAttr parDim)
static bool extractWaitClause(ComputeConstructT computeConstruct, DeviceType clauseDeviceType, MLIRContext *context, std::optional< Value > &waitDevnum, SmallVectorImpl< Value > &waitOperands, UnitAttr &waitOnly)
Extract wait for clauseDeviceType. Returns true if a clause was found.
static bool extractAsyncClause(ComputeConstructT computeConstruct, DeviceType clauseDeviceType, MLIRContext *context, std::optional< Value > &asyncOperand, UnitAttr &asyncOnly)
Extract async for clauseDeviceType. Returns true if a clause was found.
static void populateKernelEnvironmentAsyncWait(ComputeConstructT computeConstruct, DeviceType deviceType, std::optional< Value > &asyncOperand, UnitAttr &asyncOnly, std::optional< Value > &waitDevnum, SmallVectorImpl< Value > &waitOperands, UnitAttr &waitOnly)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
bool isOperation() const
Return true if the successor is an operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Include the generated interface declarations.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.