19 #include "llvm/ADT/TypeSwitch.h"
24 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
25 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
30 struct MemRefPointerLikeModel
31 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
34 return llvm::cast<MemRefType>(pointer).getElementType();
38 struct LLVMPointerPointerLikeModel
39 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
40 LLVM::LLVMPointerType> {
49 void OpenACCDialect::initialize() {
52 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
55 #define GET_ATTRDEF_LIST
56 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
59 #define GET_TYPEDEF_LIST
60 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
66 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
67 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
75 auto extent = getExtent();
76 auto upperbound = getUpperbound();
77 if (!extent && !upperbound)
78 return emitError(
"expected extent or upperbound.");
88 "data clause associated with private operation must match its intent");
97 return emitError(
"data clause associated with firstprivate operation must "
107 return emitError(
"data clause associated with reduction operation must "
117 return emitError(
"data clause associated with deviceptr operation must "
128 "data clause associated with present operation must match its intent");
137 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
142 "data clause associated with copyin operation must match its intent"
143 " or specify original clause this operation was decomposed from");
147 bool acc::CopyinOp::isCopyinReadonly() {
148 return getDataClause() == acc::DataClause::acc_copyin_readonly;
161 "data clause associated with create operation must match its intent"
162 " or specify original clause this operation was decomposed from");
166 bool acc::CreateOp::isCreateZero() {
168 return getDataClause() == acc::DataClause::acc_create_zero ||
177 return emitError(
"data clause associated with no_create operation must "
188 "data clause associated with attach operation must match its intent");
197 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
198 return emitError(
"data clause associated with device_resident operation "
199 "must match its intent");
210 "data clause associated with link operation must match its intent");
224 "data clause associated with copyout operation must match its intent"
225 " or specify original clause this operation was decomposed from");
227 return emitError(
"must have both host and device pointers");
231 bool acc::CopyoutOp::isCopyoutZero() {
246 getDataClause() != acc::DataClause::acc_declare_device_resident &&
249 "data clause associated with delete operation must match its intent"
250 " or specify original clause this operation was decomposed from");
252 return emitError(
"must have either host or device pointer");
264 "data clause associated with detach operation must match its intent"
265 " or specify original clause this operation was decomposed from");
267 return emitError(
"must have either host or device pointer");
279 "data clause associated with host operation must match its intent"
280 " or specify original clause this operation was decomposed from");
282 return emitError(
"must have both host and device pointers");
293 "data clause associated with device operation must match its intent"
294 " or specify original clause this operation was decomposed from");
305 "data clause associated with use_device operation must match its intent"
306 " or specify original clause this operation was decomposed from");
318 "data clause associated with cache operation must match its intent"
319 " or specify original clause this operation was decomposed from");
323 template <
typename StructureOp>
325 unsigned nRegions = 1) {
328 for (
unsigned i = 0; i < nRegions; ++i)
329 regions.push_back(state.addRegion());
331 for (
Region *region : regions)
339 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
346 template <
typename OpTy>
353 Value ifCond = op.getIfCond();
357 IntegerAttr constAttr;
360 if (constAttr.getInt())
373 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
385 template <
typename OpTy>
386 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
392 Value ifCond = op.getIfCond();
396 IntegerAttr constAttr;
399 if (constAttr.getInt())
415 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
417 if (optional && region.
empty())
421 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
425 return op->
emitOpError() <<
"expects " << regionName
428 << regionType <<
" type";
431 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
432 if (yieldOp.getOperands().size() != 1 ||
433 yieldOp.getOperands().getTypes()[0] != type)
434 return op->
emitOpError() <<
"expects " << regionName
436 "yield a value of the "
437 << regionType <<
" type";
445 "privatization",
"init", getType(),
449 *
this, getDestroyRegion(),
"privatization",
"destroy", getType(),
461 "privatization",
"init", getType(),
465 if (getCopyRegion().empty())
466 return emitOpError() <<
"expects non-empty copy region";
471 return emitOpError() <<
"expects copy region with two arguments of the "
472 "privatization type";
474 if (getDestroyRegion().empty())
478 "privatization",
"destroy",
495 if (getCombinerRegion().empty())
496 return emitOpError() <<
"expects non-empty combiner region";
498 Block &reductionBlock = getCombinerRegion().
front();
502 return emitOpError() <<
"expects combiner region with the first two "
503 <<
"arguments of the reduction type";
505 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
506 if (yieldOp.getOperands().size() != 1 ||
507 yieldOp.getOperands().getTypes()[0] != getType())
508 return emitOpError() <<
"expects combiner region to yield a value "
509 "of the reduction type";
525 if (parser.parseAttribute(attributes.emplace_back()) ||
526 parser.parseArrow() ||
527 parser.parseOperand(operands.emplace_back()) ||
528 parser.parseColonType(types.emplace_back()))
542 std::optional<mlir::ArrayAttr> attributes) {
543 for (
unsigned i = 0, e = attributes->size(); i < e; ++i) {
546 p << (*attributes)[i] <<
" -> " << operands[i] <<
" : "
556 template <
typename Op>
560 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
561 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
562 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
563 operand.getDefiningOp()))
565 "expect data entry/exit operation or acc.getdeviceptr "
570 template <
typename Op>
574 llvm::StringRef symbolName,
bool checkOperandType =
true) {
575 if (!operands.empty()) {
576 if (!attributes || attributes->size() != operands.size())
578 <<
"expected as many " << symbolName <<
" symbol reference as "
579 << operandName <<
" operands";
583 <<
"unexpected " << symbolName <<
" symbol reference";
588 for (
auto args : llvm::zip(operands, *attributes)) {
591 if (!set.insert(operand).second)
593 << operandName <<
" operand appears more than once";
596 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
597 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
600 <<
"expected symbol reference " << symbolRef <<
" to point to a "
601 << operandName <<
" declaration";
603 if (checkOperandType && decl.getType() && decl.getType() != varType)
604 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
605 <<
") to be the same type as " << operandName
606 <<
" declaration (" << decl.getType() <<
")";
612 unsigned ParallelOp::getNumDataOperands() {
613 return getReductionOperands().size() + getGangPrivateOperands().size() +
614 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
617 Value ParallelOp::getDataOperand(
unsigned i) {
618 unsigned numOptional = getAsync() ? 1 : 0;
619 numOptional += getNumGangs().size();
620 numOptional += getNumWorkers() ? 1 : 0;
621 numOptional += getVectorLength() ? 1 : 0;
622 numOptional += getIfCond() ? 1 : 0;
623 numOptional += getSelfCond() ? 1 : 0;
624 return getOperand(getWaitOperands().size() + numOptional + i);
628 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
629 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
630 "privatizations",
false)))
632 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
633 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
634 "reductions",
false)))
636 if (getNumGangs().size() > 3)
637 return emitOpError() <<
"num_gangs expects a maximum of 3 values";
638 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
645 unsigned SerialOp::getNumDataOperands() {
646 return getReductionOperands().size() + getGangPrivateOperands().size() +
647 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
650 Value SerialOp::getDataOperand(
unsigned i) {
651 unsigned numOptional = getAsync() ? 1 : 0;
652 numOptional += getIfCond() ? 1 : 0;
653 numOptional += getSelfCond() ? 1 : 0;
654 return getOperand(getWaitOperands().size() + numOptional + i);
658 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
659 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
660 "privatizations",
false)))
662 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
663 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
664 "reductions",
false)))
666 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
673 unsigned KernelsOp::getNumDataOperands() {
674 return getDataClauseOperands().size();
677 Value KernelsOp::getDataOperand(
unsigned i) {
678 unsigned numOptional = getAsync() ? 1 : 0;
679 numOptional += getWaitOperands().size();
680 numOptional += getNumGangs().size();
681 numOptional += getNumWorkers() ? 1 : 0;
682 numOptional += getVectorLength() ? 1 : 0;
683 numOptional += getIfCond() ? 1 : 0;
684 numOptional += getSelfCond() ? 1 : 0;
685 return getOperand(numOptional + i);
689 if (getNumGangs().size() > 3)
690 return emitOpError() <<
"num_gangs expects a maximum of 3 values";
691 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
699 if (getDataClauseOperands().empty())
700 return emitError(
"at least one operand must appear on the host_data "
703 for (
mlir::Value operand : getDataClauseOperands())
704 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
705 return emitError(
"expect data entry operation as defining op");
711 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
720 std::optional<OpAsmParser::UnresolvedOperand> &value,
721 Type &valueType,
bool &needComa,
bool &newValue) {
735 OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &gangNum,
736 Type &gangNumType, std::optional<OpAsmParser::UnresolvedOperand> &gangDim,
738 std::optional<OpAsmParser::UnresolvedOperand> &gangStatic,
739 Type &gangStaticType, UnitAttr &hasGang) {
741 gangNum = std::nullopt;
742 gangDim = std::nullopt;
743 gangStatic = std::nullopt;
744 bool needComa =
false;
749 bool newValue =
false;
750 bool needValue =
false;
759 gangNumType, needComa, newValue)))
762 gangDimType, needComa, newValue)))
765 gangStatic, gangStaticType, needComa,
769 if (!newValue && needValue) {
771 "new value expected after comma");
779 if (!gangNum && !gangDim && !gangStatic) {
781 "expect at least one of num, dim or static values");
793 Value gangStatic,
Type gangStaticType, UnitAttr hasGang) {
794 if (gangNum || gangStatic || gangDim) {
797 p << LoopOp::getGangNumKeyword() <<
"=" << gangNum <<
" : "
799 if (gangStatic || gangDim)
803 p << LoopOp::getGangDimKeyword() <<
"=" << gangDim <<
" : "
809 p << LoopOp::getGangStaticKeyword() <<
"=" << gangStatic <<
" : "
817 std::optional<OpAsmParser::UnresolvedOperand> &workerNum,
818 Type &workerNumType, UnitAttr &hasWorker) {
830 Type workerNumType, UnitAttr hasWorker) {
832 p <<
"(" << workerNum <<
" : " << workerNumType <<
")";
837 std::optional<OpAsmParser::UnresolvedOperand> &vectorLength,
838 Type &vectorLengthType, UnitAttr &hasVector) {
850 Type vectorLengthType, UnitAttr hasVector) {
852 p <<
"(" << vectorLength <<
" : " << vectorLengthType <<
")";
857 if ((getAuto_() && (getIndependent() || getSeq())) ||
858 (getIndependent() && getSeq())) {
859 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
860 <<
"\", " << getIndependentAttrName() <<
", "
862 <<
" can be present at the same time";
866 if (getSeq() && (getHasGang() || getHasWorker() || getHasVector()))
867 return emitError(
"gang, worker or vector cannot appear with the seq attr");
869 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
870 *
this, getPrivatizations(), getPrivateOperands(),
"private",
871 "privatizations",
false)))
874 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
875 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
876 "reductions",
false)))
880 if (getRegion().empty())
881 return emitError(
"expected non-empty body.");
886 unsigned LoopOp::getNumDataOperands() {
887 return getReductionOperands().size() + getPrivateOperands().size();
890 Value LoopOp::getDataOperand(
unsigned i) {
891 unsigned numOptional = getGangNum() ? 1 : 0;
892 numOptional += getGangDim() ? 1 : 0;
893 numOptional += getGangStatic() ? 1 : 0;
894 numOptional += getVectorLength() ? 1 : 0;
895 numOptional += getWorkerNum() ? 1 : 0;
896 numOptional += getTileOperands().size();
897 numOptional += getCacheOperands().size();
898 return getOperand(numOptional + i);
909 if (getOperands().empty() && !getDefaultAttr())
910 return emitError(
"at least one operand or the default attribute "
911 "must appear on the data operation");
913 for (
mlir::Value operand : getDataClauseOperands())
914 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
915 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
916 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
917 operand.getDefiningOp()))
918 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
924 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
926 Value DataOp::getDataOperand(
unsigned i) {
927 unsigned numOptional = getIfCond() ? 1 : 0;
928 numOptional += getAsync() ? 1 : 0;
929 numOptional += getWaitOperands().size();
930 return getOperand(numOptional + i);
941 if (getDataClauseOperands().empty())
942 return emitError(
"at least one operand must be present in dataOperands on "
943 "the exit data operation");
947 if (getAsyncOperand() && getAsync())
948 return emitError(
"async attribute cannot appear with asyncOperand");
952 if (!getWaitOperands().empty() && getWait())
953 return emitError(
"wait attribute cannot appear with waitOperands");
955 if (getWaitDevnum() && getWaitOperands().empty())
956 return emitError(
"wait_devnum cannot appear without waitOperands");
961 unsigned ExitDataOp::getNumDataOperands() {
962 return getDataClauseOperands().size();
965 Value ExitDataOp::getDataOperand(
unsigned i) {
966 unsigned numOptional = getIfCond() ? 1 : 0;
967 numOptional += getAsyncOperand() ? 1 : 0;
968 numOptional += getWaitDevnum() ? 1 : 0;
969 return getOperand(getWaitOperands().size() + numOptional + i);
974 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
985 if (getDataClauseOperands().empty())
986 return emitError(
"at least one operand must be present in dataOperands on "
987 "the enter data operation");
991 if (getAsyncOperand() && getAsync())
992 return emitError(
"async attribute cannot appear with asyncOperand");
996 if (!getWaitOperands().empty() && getWait())
997 return emitError(
"wait attribute cannot appear with waitOperands");
999 if (getWaitDevnum() && getWaitOperands().empty())
1000 return emitError(
"wait_devnum cannot appear without waitOperands");
1002 for (
mlir::Value operand : getDataClauseOperands())
1003 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
1004 operand.getDefiningOp()))
1005 return emitError(
"expect data entry operation as defining op");
1010 unsigned EnterDataOp::getNumDataOperands() {
1011 return getDataClauseOperands().size();
1014 Value EnterDataOp::getDataOperand(
unsigned i) {
1015 unsigned numOptional = getIfCond() ? 1 : 0;
1016 numOptional += getAsyncOperand() ? 1 : 0;
1017 numOptional += getWaitDevnum() ? 1 : 0;
1018 return getOperand(getWaitOperands().size() + numOptional + i);
1023 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
1042 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
1049 if (
Value writeVal = op.getWriteOpVal()) {
1059 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
1065 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
1066 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
1068 return dyn_cast<AtomicReadOp>(getSecondOp());
1071 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
1072 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
1074 return dyn_cast<AtomicWriteOp>(getSecondOp());
1077 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
1078 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
1080 return dyn_cast<AtomicUpdateOp>(getSecondOp());
1083 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
1089 template <
typename Op>
1092 bool requireAtLeastOneOperand =
true) {
1093 if (operands.empty() && requireAtLeastOneOperand)
1096 "at least one operand must appear on the declare operation");
1099 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1100 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
1101 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
1102 operand.getDefiningOp()))
1104 "expect valid declare data entry operation or acc.getdeviceptr "
1108 assert(varPtr &&
"declare operands can only be data entry operations which "
1109 "must have varPtr");
1110 std::optional<mlir::acc::DataClause> dataClauseOptional{
1112 assert(dataClauseOptional.has_value() &&
1113 "declare operands can only be data entry operations which must have "
1117 if (!varPtr.getDefiningOp())
1121 auto declareAttribute{
1123 if (!declareAttribute)
1125 "expect declare attribute on variable in declare operation");
1127 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
1128 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
1130 "expect matching declare attribute on variable in declare operation");
1137 if (declAttr.getImplicit() &&
1140 "implicitness must match between declare op and flag on variable");
1174 int parallelism = 0;
1175 parallelism += getGang() ? 1 : 0;
1176 parallelism += getWorker() ? 1 : 0;
1177 parallelism += getVector() ? 1 : 0;
1178 parallelism += getSeq() ? 1 : 0;
1180 if (parallelism > 1)
1181 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
1182 "be present at the same time";
1188 IntegerAttr &gangDim) {
1218 IntegerAttr gangDim) {
1220 p <<
"(" << RoutineOp::getGangDimKeyword() <<
" = " << gangDim.getValue()
1221 <<
" : " << gangDim.getType() <<
")";
1232 return emitOpError(
"cannot be nested in a compute operation");
1244 return emitOpError(
"cannot be nested in a compute operation");
1256 return emitOpError(
"cannot be nested in a compute operation");
1257 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
1258 return emitOpError(
"at least one default_async, device_num, or device_type "
1259 "operand must appear");
1269 if (getDataClauseOperands().empty())
1270 return emitError(
"at least one value must be present in dataOperands");
1274 if (getAsyncOperand() && getAsync())
1275 return emitError(
"async attribute cannot appear with asyncOperand");
1279 if (!getWaitOperands().empty() && getWait())
1280 return emitError(
"wait attribute cannot appear with waitOperands");
1282 if (getWaitDevnum() && getWaitOperands().empty())
1283 return emitError(
"wait_devnum cannot appear without waitOperands");
1285 for (
mlir::Value operand : getDataClauseOperands())
1286 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
1287 operand.getDefiningOp()))
1288 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
1294 unsigned UpdateOp::getNumDataOperands() {
1295 return getDataClauseOperands().size();
1298 Value UpdateOp::getDataOperand(
unsigned i) {
1299 unsigned numOptional = getAsyncOperand() ? 1 : 0;
1300 numOptional += getWaitDevnum() ? 1 : 0;
1301 numOptional += getIfCond() ? 1 : 0;
1302 return getOperand(getWaitOperands().size() + numOptional + i);
1307 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
1317 if (getAsyncOperand() && getAsync())
1318 return emitError(
"async attribute cannot appear with asyncOperand");
1320 if (getWaitDevnum() && getWaitOperands().empty())
1321 return emitError(
"wait_devnum cannot appear without waitOperands");
1326 #define GET_OP_CLASSES
1327 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
1329 #define GET_ATTRDEF_CLASSES
1330 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
1332 #define GET_TYPEDEF_CLASSES
1333 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
1342 [&](
auto entry) {
return entry.getVarPtr(); })
1347 std::optional<mlir::acc::DataClause>
1352 .Case<ACC_DATA_ENTRY_OPS>(
1353 [&](
auto entry) {
return entry.getDataClause(); })
1361 [&](
auto entry) {
return entry.getImplicit(); })
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
static bool isComputeOperation(Operation *op)
static ParseResult parseWorkerClause(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &workerNum, Type &workerNumType, UnitAttr &hasWorker)
void printVectorClause(OpAsmPrinter &p, Operation *op, Value vectorLength, Type vectorLengthType, UnitAttr hasVector)
static ParseResult parseVectorClause(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &vectorLength, Type &vectorLengthType, UnitAttr &hasVector)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, std::optional< OpAsmParser::UnresolvedOperand > &value, Type &valueType, bool &needComa, bool &newValue)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang, IntegerAttr &gangDim)
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static ParseResult parseGangClause(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &gangNum, Type &gangNumType, std::optional< OpAsmParser::UnresolvedOperand > &gangDim, Type &gangDimType, std::optional< OpAsmParser::UnresolvedOperand > &gangStatic, Type &gangStaticType, UnitAttr &hasGang)
void printGangClause(OpAsmPrinter &p, Operation *op, Value gangNum, Type gangNumType, Value gangDim, Type gangDimType, Value gangStatic, Type gangStaticType, UnitAttr hasGang)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
void printWorkerClause(OpAsmPrinter &p, Operation *op, Value workerNum, Type workerNumType, UnitAttr hasWorker)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang, IntegerAttr gangDim)
#define ACC_DATA_ENTRY_OPS
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
MLIRContext * getContext() const
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::Value getVarPtr(mlir::Operation *accDataEntryOp)
Used to obtain the varPtr from a data entry operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.