34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/StringSaver.h"
44 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
50 int64_t GPUBlockMappingAttr::getMappingId()
const {
51 return static_cast<int64_t
>(getBlock());
54 bool GPUBlockMappingAttr::isLinearMapping()
const {
55 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
58 int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
59 return isLinearMapping()
60 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
64 int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
65 return static_cast<int64_t
>(getWarpgroup());
68 bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
69 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
72 int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
73 return isLinearMapping()
74 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
78 int64_t GPUWarpMappingAttr::getMappingId()
const {
79 return static_cast<int64_t
>(getWarp());
82 bool GPUWarpMappingAttr::isLinearMapping()
const {
83 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
86 int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
87 return isLinearMapping()
88 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
92 int64_t GPUThreadMappingAttr::getMappingId()
const {
93 return static_cast<int64_t
>(getThread());
96 bool GPUThreadMappingAttr::isLinearMapping()
const {
97 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
100 int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
101 return isLinearMapping()
102 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
106 int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
107 return static_cast<int64_t
>(getAddressSpace());
110 bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
111 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
114 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
115 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
132 elementType, operand);
146 return elementType.
isF16() || elementType.
isF32() ||
155 if (!operand.equals(
"AOp") && !operand.equals(
"BOp") &&
156 !operand.equals(
"COp"))
157 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
159 if (shape.size() != 2)
160 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
164 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
173 bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
176 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
177 return gpuAttr.getValue() == getWorkgroupAddressSpace();
181 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
182 Attribute memorySpace = type.getMemorySpace();
183 return isWorkgroupMemoryAddressSpace(memorySpace);
186 bool GPUDialect::isKernel(
Operation *op) {
187 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
188 return static_cast<bool>(isKernelAttr);
204 void GPUDialect::initialize() {
205 addTypes<AsyncTokenType>();
206 addTypes<MMAMatrixType>();
207 addTypes<SparseDnTensorHandleType>();
208 addTypes<SparseSpMatHandleType>();
209 addTypes<SparseSpGEMMOpHandleType>();
212 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
215 #define GET_ATTRDEF_LIST
216 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
218 addInterfaces<GPUInlinerInterface>();
219 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
226 return "sparse.dntensor_handle";
228 return "sparse.spmat_handle";
230 return "sparse.spgemmop_handle";
232 llvm_unreachable(
"unknown sparse handle kind");
244 if (keyword ==
"async.token")
247 if (keyword ==
"mma_matrix") {
276 shape, elementType, operand);
294 .Case<SparseDnTensorHandleType>([&](
Type) {
297 .Case<SparseSpMatHandleType>(
299 .Case<SparseSpGEMMOpHandleType>([&](
Type) {
305 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
308 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
310 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
315 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
316 attr.
getName() != getContainerModuleAttrName())
319 auto module = dyn_cast<ModuleOp>(op);
322 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
323 << ModuleOp::getOperationName() <<
'\'';
325 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
328 if (!launchOp->getParentOp() ||
329 launchOp->getParentOp()->getParentOp() != module)
334 if (!launchOp->getAttrOfType<SymbolRefAttr>(
335 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
339 StringAttr kernelContainerName = launchOp.getKernelModuleName();
340 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
341 if (!kernelContainer)
343 <<
"kernel container '" << kernelContainerName.getValue()
347 if (isa<BinaryOp>(kernelContainer))
350 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
352 return launchOp.emitOpError()
353 <<
"kernel module '" << kernelContainerName.getValue()
357 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
360 << launchOp.getKernel() <<
"' is undefined";
361 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
362 if (!kernelConvertedFunction) {
364 <<
"referenced kernel '" << launchOp.getKernel()
365 <<
"' is not a function";
366 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
371 GPUDialect::getKernelFuncAttrName()))
372 return launchOp.emitOpError(
"kernel function is missing the '")
373 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
378 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
379 if (!kernelGPUFunction)
382 unsigned actualNumArguments = launchOp.getNumKernelOperands();
383 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
384 if (expectedNumArguments != actualNumArguments)
385 return launchOp.emitOpError(
"got ")
386 << actualNumArguments <<
" kernel operands but expected "
387 << expectedNumArguments;
389 auto functionType = kernelGPUFunction.getFunctionType();
390 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
391 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
392 return launchOp.emitOpError(
"type of function argument ")
393 << i <<
" does not match";
413 return parser.
emitError(loc,
"needs to be named when marked 'async'");
428 if (asyncDependencies.empty())
433 llvm::interleaveComma(asyncDependencies, printer);
462 p <<
' ' << keyword <<
'(';
463 llvm::interleaveComma(
471 gpu::AddressSpace memorySpace) {
472 for (
Value v : attributions) {
473 auto type = llvm::dyn_cast<MemRefType>(v.getType());
475 return op->
emitOpError() <<
"expected memref type in attribution";
480 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
483 if (addressSpace.getValue() != memorySpace)
485 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
486 <<
" in attribution";
497 using Kind = gpu::AllReduceOperation;
498 if (llvm::is_contained(
499 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
501 if (!isa<FloatType>(resType))
505 if (llvm::is_contained({Kind::MINSI,
Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
506 Kind::AND, Kind::OR, Kind::XOR},
508 if (!isa<IntegerType>(resType))
516 if (getBody().empty() != getOp().has_value())
517 return emitError(
"expected either an op attribute or a non-empty body");
518 if (!getBody().empty()) {
519 if (getBody().getNumArguments() != 2)
520 return emitError(
"expected two region arguments");
521 for (
auto argument : getBody().getArguments()) {
522 if (argument.getType() != getType())
523 return emitError(
"incorrect region argument type");
525 unsigned yieldCount = 0;
526 for (
Block &block : getBody()) {
527 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
528 if (yield.getNumOperands() != 1)
529 return emitError(
"expected one gpu.yield operand");
530 if (yield.getOperand(0).getType() != getType())
531 return emitError(
"incorrect gpu.yield type");
536 return emitError(
"expected gpu.yield op in region");
538 gpu::AllReduceOperation opName = *getOp();
540 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
541 <<
"` reduction operation is not compatible with type "
550 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
554 Region &body = launchOp.getBody();
555 assert(!body.
empty() &&
"Invalid region");
572 AllReduceOperationAttr &attr) {
575 std::optional<AllReduceOperation> op =
576 gpu::symbolizeAllReduceOperation(enumStr);
585 AllReduceOperationAttr attr) {
595 Type elemType = getType();
596 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
597 if (vecTy.isScalable())
598 return emitOpError() <<
"is not compatible with scalable vector types";
600 elemType = vecTy.getElementType();
603 gpu::AllReduceOperation opName = getOp();
605 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
606 <<
"` reduction operation is not compatible with type "
612 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
627 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
631 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
649 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
658 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
667 result.
addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
668 getBlockSizeY, getBlockSizeZ});
675 if (dynamicSharedMemorySize)
684 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
687 for (
Type argTy : workgroupAttributions)
689 for (
Type argTy : privateAttributions)
693 segmentSizes.front() = asyncDependencies.size();
694 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
695 segmentSizes[7] = clusterSizeX ? 1 : 0;
696 segmentSizes[8] = clusterSizeY ? 1 : 0;
697 segmentSizes[9] = clusterSizeZ ? 1 : 0;
703 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
704 auto args = getBody().getArguments();
709 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
710 auto args = getBody().getArguments();
715 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
716 auto args = getBody().getArguments();
721 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
722 auto args = getBody().getArguments();
723 return KernelDim3{args[9], args[10], args[11]};
726 std::optional<KernelDim3> LaunchOp::getClusterIds() {
727 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
728 if (!hasClusterSize())
730 auto args = getBody().getArguments();
731 return KernelDim3{args[12], args[13], args[14]};
734 std::optional<KernelDim3> LaunchOp::getClusterSize() {
735 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
736 if (!hasClusterSize())
738 auto args = getBody().getArguments();
739 return KernelDim3{args[15], args[16], args[17]};
742 KernelDim3 LaunchOp::getGridSizeOperandValues() {
743 auto operands = getOperands().drop_front(getAsyncDependencies().size());
744 return KernelDim3{operands[0], operands[1], operands[2]};
747 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
748 auto operands = getOperands().drop_front(getAsyncDependencies().size());
749 return KernelDim3{operands[3], operands[4], operands[5]};
752 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
753 auto operands = getOperands().drop_front(getAsyncDependencies().size());
754 if (!hasClusterSize())
756 return KernelDim3{operands[6], operands[7], operands[8]};
760 if (!(hasClusterSize()) &&
761 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
762 return emitOpError() <<
"cluster size must be all present";
770 if (!getBody().empty()) {
771 if (getBody().getNumArguments() <
772 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
773 return emitOpError(
"unexpected number of region arguments");
778 GPUDialect::getWorkgroupAddressSpace())) ||
780 GPUDialect::getPrivateAddressSpace())))
785 for (
Block &block : getBody()) {
788 if (block.back().getNumSuccessors() != 0)
790 if (!isa<gpu::TerminatorOp>(&block.back())) {
793 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
794 "' or a terminator with successors")
795 .attachNote(getLoc())
796 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
800 if (getNumResults() == 0 && getAsyncToken())
801 return emitOpError(
"needs to be named when async keyword is specified");
812 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
813 p << size.
x <<
" = " << operands.
x <<
", ";
814 p << size.
y <<
" = " << operands.
y <<
", ";
815 p << size.
z <<
" = " << operands.
z <<
')';
819 if (getAsyncToken()) {
821 if (!getAsyncDependencies().empty())
822 p <<
" [" << getAsyncDependencies() <<
']';
825 if (hasClusterSize()) {
826 p <<
' ' << getClustersKeyword();
828 getClusterSizeOperandValues().value(),
829 getClusterIds().value());
831 p <<
' ' << getBlocksKeyword();
834 p <<
' ' << getThreadsKeyword();
837 if (getDynamicSharedMemorySize())
838 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
839 << getDynamicSharedMemorySize();
848 LaunchOp::getOperandSegmentSizeAttr(),
849 getNumWorkgroupAttributionsAttrName()});
863 assert(indices.size() == 3 &&
"space for three indices expected");
869 std::move(args.begin(), args.end(), indices.begin());
871 for (
int i = 0; i < 3; ++i) {
893 sizes(LaunchOp::kNumConfigOperands);
900 LaunchOp::kNumConfigRegionAttributes);
911 result.
types.push_back(asyncTokenType);
913 bool hasCluster =
false;
918 regionArgs.resize(18);
927 regionArgsRef.slice(15, 3),
928 regionArgsRef.slice(12, 3)))
936 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
938 regionArgsRef.slice(6, 3),
939 regionArgsRef.slice(0, 3)) ||
940 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
942 regionArgsRef.slice(9, 3),
943 regionArgsRef.slice(3, 3)) ||
949 bool hasDynamicSharedMemorySize =
false;
951 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
952 hasDynamicSharedMemorySize =
true;
968 LaunchOp::kNumConfigRegionAttributes + 6, index);
971 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
973 arg.
ssaName = std::get<0>(ssaValueAndType);
974 arg.
type = std::get<1>(ssaValueAndType);
975 regionArguments.push_back(arg);
986 unsigned numWorkgroupAttrs = regionArguments.size() -
987 LaunchOp::kNumConfigRegionAttributes -
988 (hasCluster ? 6 : 0);
989 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1006 segmentSizes.front() = asyncDependencies.size();
1009 segmentSizes[7] = 0;
1010 segmentSizes[8] = 0;
1011 segmentSizes[9] = 0;
1013 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1014 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1028 bool simplified =
false;
1029 auto constPropIdUses = [&](
Value id,
Value size) {
1033 if (
id.getUses().empty())
1040 rewriter.
create<arith::ConstantIndexOp>(op.
getLoc(), 0);
1045 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1046 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1047 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1048 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1049 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1050 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1064 auto attrName = getNumWorkgroupAttributionsAttrName();
1065 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1066 (*this)->setAttr(attrName,
1068 return getBody().insertArgument(
1069 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1077 return getBody().addArgument(type, loc);
1089 std::optional<KernelDim3> clusterSize) {
1097 if (clusterSize.has_value())
1098 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1099 if (dynamicSharedMemorySize)
1102 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1105 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1108 prop.kernel = kernelSymbol;
1109 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1111 for (
auto &sz : prop.operandSegmentSizes)
1113 prop.operandSegmentSizes[0] = asyncDependencies.size();
1114 if (!clusterSize.has_value()) {
1115 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1116 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1117 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1119 prop.operandSegmentSizes[segmentSizesLen - 3] =
1120 dynamicSharedMemorySize ? 1 : 0;
1121 prop.operandSegmentSizes[segmentSizesLen - 2] =
1122 static_cast<int32_t
>(kernelOperands.size());
1123 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1130 std::optional<KernelDim3> clusterSize) {
1134 if (clusterSize.has_value())
1135 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1136 if (dynamicSharedMemorySize)
1142 prop.kernel = kernel;
1143 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1145 for (
auto &sz : prop.operandSegmentSizes)
1147 prop.operandSegmentSizes[0] = 0;
1148 if (!clusterSize.has_value()) {
1149 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1150 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1151 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1153 prop.operandSegmentSizes[segmentSizesLen - 3] =
1154 dynamicSharedMemorySize ? 1 : 0;
1155 prop.operandSegmentSizes[segmentSizesLen - 2] =
1156 static_cast<int32_t
>(kernelOperands.size());
1157 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1160 StringAttr LaunchFuncOp::getKernelModuleName() {
1164 StringAttr LaunchFuncOp::getKernelName() {
1168 unsigned LaunchFuncOp::getNumKernelOperands() {
1169 return getKernelOperands().size();
1172 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1173 return getKernelOperands()[i];
1176 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1177 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1178 return KernelDim3{operands[0], operands[1], operands[2]};
1181 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1182 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1183 return KernelDim3{operands[3], operands[4], operands[5]};
1186 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1187 assert(hasClusterSize() &&
1188 "cluster size is not set, check hasClusterSize() first");
1189 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1190 return KernelDim3{operands[6], operands[7], operands[8]};
1194 auto module = (*this)->getParentOfType<ModuleOp>();
1196 return emitOpError(
"expected to belong to a module");
1198 if (!module->getAttrOfType<UnitAttr>(
1199 GPUDialect::getContainerModuleAttrName()))
1200 return emitOpError(
"expected the closest surrounding module to have the '" +
1201 GPUDialect::getContainerModuleAttrName() +
1204 if (hasClusterSize()) {
1205 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1206 getClusterSizeZ().getType() != getClusterSizeX().getType())
1207 return emitOpError()
1208 <<
"expects types of the cluster dimensions must be the same";
1216 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1217 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1224 if (clusterValue.has_value()) {
1225 clusterXTy = clusterYTy = clusterZTy = dimTy;
1232 Type clusterYTy,
Type clusterZTy) {
1234 printer <<
": " << dimTy;
1250 parseElement,
" in argument list");
1255 if (operands.empty())
1258 llvm::interleaveComma(llvm::zip(operands, types), printer,
1259 [&](
const auto &pair) {
1272 int32_t offset, int32_t width, ShuffleMode mode) {
1273 build(builder, result, value,
1290 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1301 results.
add(eraseRedundantGpuBarrierOps);
1311 auto attrName = getNumWorkgroupAttributionsAttrName();
1312 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1313 (*this)->setAttr(attrName,
1315 return getBody().insertArgument(
1316 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1324 return getBody().addArgument(type, loc);
1328 StringRef name, FunctionType type,
1338 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
1345 for (
Type argTy : type.getInputs())
1347 for (
Type argTy : workgroupAttributions)
1349 for (
Type argTy : privateAttributions)
1368 size_t existingArgs = args.size();
1375 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1380 attributionAttrs =
nullptr;
1386 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1387 if (!argument.attrs)
1390 attributionAttrsVec.push_back(argument.attrs);
1392 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1408 StringAttr nameAttr;
1415 parser,
false, entryArgs, isVariadic, resultTypes,
1419 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1420 return parser.
emitError(signatureLocation)
1421 <<
"gpu.func requires named arguments";
1428 for (
auto &arg : entryArgs)
1429 argTypes.push_back(arg.
type);
1435 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
1436 getResAttrsAttrName(result.
name));
1441 entryArgs, workgroupAttributionAttrs)))
1446 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1447 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1449 if (workgroupAttributionAttrs)
1450 result.
addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.
name),
1451 workgroupAttributionAttrs);
1456 entryArgs, privateAttributionAttrs)))
1458 if (privateAttributionAttrs)
1460 privateAttributionAttrs);
1464 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
1479 ArrayAttr attributes) {
1483 p <<
' ' << keyword <<
'(';
1484 llvm::interleaveComma(
1487 p << v <<
" : " << v.
getType();
1489 size_t attributionIndex = pair.index();
1490 DictionaryAttr attrs;
1491 if (attributes && attributionIndex < attributes.size())
1492 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1503 FunctionType type = getFunctionType();
1509 getWorkgroupAttribAttrs().value_or(
nullptr));
1511 getPrivateAttribAttrs().value_or(
nullptr));
1513 p <<
' ' << getKernelKeyword();
1517 {getNumWorkgroupAttributionsAttrName(),
1518 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1519 getArgAttrsAttrName(), getResAttrsAttrName(),
1520 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1526 StringAttr attrName) {
1527 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->
getAttr(attrName));
1528 if (!allAttrs || index >= allAttrs.size())
1529 return DictionaryAttr();
1530 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1533 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1537 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1542 DictionaryAttr value, StringAttr attrName) {
1544 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->
getAttr(attrName));
1547 elements.append(allAttrs.begin(), allAttrs.end());
1548 while (elements.size() <= index)
1553 elements[index] = value;
1555 op->
setAttr(attrName, newValue);
1558 void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1559 DictionaryAttr value) {
1563 void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1564 DictionaryAttr value) {
1569 StringAttr name, StringAttr attrsName) {
1573 return dict.get(name);
1576 Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1578 assert(index < getNumWorkgroupAttributions() &&
1579 "index must map to a workgroup attribution");
1581 getWorkgroupAttribAttrsAttrName());
1584 Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1586 assert(index < getNumPrivateAttributions() &&
1587 "index must map to a private attribution");
1589 getPrivateAttribAttrsAttrName());
1593 Attribute value, StringAttr attrsName) {
1598 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1601 bool mustSort =
true;
1602 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1603 if (elems[i].getName() == name) {
1606 std::swap(elems[i], elems[elems.size() - 1]);
1618 elems.emplace_back(name, value);
1621 DictionaryAttr::sortInPlace(elems);
1623 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1627 void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1629 assert(index < getNumWorkgroupAttributions() &&
1630 "index must map to a workgroup attribution");
1632 getWorkgroupAttribAttrsAttrName());
1635 void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1637 assert(index < getNumPrivateAttributions() &&
1638 "index must map to a private attribution");
1640 getPrivateAttribAttrsAttrName());
1644 if (isKernel() && getFunctionType().getNumResults() != 0)
1645 return emitOpError() <<
"expected void return type for kernel function";
1653 return emitOpError() <<
"expected body with at least one block";
1654 unsigned numFuncArguments = getNumArguments();
1655 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1656 unsigned numBlockArguments = front().getNumArguments();
1657 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1658 return emitOpError() <<
"expected at least "
1659 << numFuncArguments + numWorkgroupAttributions
1660 <<
" arguments to body region";
1663 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1664 Type blockArgType = front().getArgument(i).getType();
1665 if (funcArgTypes[i] != blockArgType)
1666 return emitOpError() <<
"expected body region argument #" << i
1667 <<
" to be of type " << funcArgTypes[i] <<
", got "
1672 GPUDialect::getWorkgroupAddressSpace())) ||
1674 GPUDialect::getPrivateAddressSpace())))
1681 StringRef attrName) {
1682 auto maybeAttr = op->
getAttr(attrName);
1685 auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
1687 return op.
emitOpError(attrName +
" must be a dense i32 array");
1688 if (array.size() != 3)
1689 return op.
emitOpError(attrName +
" must contain exactly 3 elements");
1706 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1708 FunctionType funType =
function.getFunctionType();
1710 if (funType.getNumResults() != getOperands().size())
1711 return emitOpError()
1712 .append(
"expected ", funType.getNumResults(),
" result operands")
1713 .attachNote(
function.getLoc())
1714 .append(
"return type declared here");
1717 llvm::zip(
function.getFunctionType().getResults(), getOperands()))) {
1718 auto [type, operand] = pair.value();
1719 if (type != operand.getType())
1720 return emitOpError() <<
"unexpected type `" << operand.getType()
1721 <<
"' for operand #" << pair.index();
1731 StringRef name, ArrayAttr targets,
1739 props.targets = targets;
1740 props.offloadingHandler = offloadingHandler;
1746 build(builder, result, name,
1747 targets.empty() ? ArrayAttr() : builder.
getArrayAttr(targets),
1752 StringAttr nameAttr;
1753 ArrayAttr targetsAttr;
1773 if (
failed(*targetsAttrResult)) {
1776 props.targets = targetsAttr;
1797 if (
Attribute attr = getOffloadingHandlerAttr()) {
1803 if (
Attribute attr = getTargetsAttr()) {
1810 {mlir::SymbolTable::getSymbolAttrName(),
1811 getTargetsAttrName(),
1812 getOffloadingHandlerAttrName()});
1818 bool GPUModuleOp::hasTarget(
Attribute target) {
1819 if (ArrayAttr targets = getTargetsAttr())
1820 return llvm::count(targets.getValue(), target);
1825 ArrayAttr &targetsAttr = getProperties().targets;
1834 Attribute offloadingHandler, ArrayAttr objects) {
1838 properties.objects = objects;
1839 if (offloadingHandler)
1840 properties.offloadingHandler = offloadingHandler;
1842 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1847 build(builder, result, name, offloadingHandler,
1848 objects.empty() ? ArrayAttr() : builder.
getArrayAttr(objects));
1859 if (!offloadingHandler)
1867 printer << '<' << offloadingHandler << '>
';
1870 //===----------------------------------------------------------------------===//
1872 //===----------------------------------------------------------------------===//
1874 LogicalResult MemcpyOp::verify() {
1875 auto srcType = getSrc().getType();
1876 auto dstType = getDst().getType();
1878 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1879 return emitOpError("arguments have incompatible element type");
1881 if (failed(verifyCompatibleShape(srcType, dstType)))
1882 return emitOpError("arguments have incompatible shape");
1891 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1892 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1894 LogicalResult matchAndRewrite(MemcpyOp op,
1895 PatternRewriter &rewriter) const override {
1896 Value dest = op.getDst();
1897 Operation *destDefOp = dest.getDefiningOp();
1898 // `dest` must be defined by an op having Allocate memory effect in order to
1899 // perform the folding.
1901 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1903 // We can erase `op` iff `dest` has no other use apart from its
1904 // use by `op` and dealloc ops.
1905 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1906 return user != op &&
1907 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1910 // We can perform the folding if and only if op has a single async
1911 // dependency and produces an async token as result, or if it does not have
1912 // any async dependency and does not produce any async token result.
1913 if (op.getAsyncDependencies().size() > 1 ||
1914 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1915 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1917 rewriter.replaceOp(op, op.getAsyncDependencies());
1922 } // end anonymous namespace
1924 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1925 MLIRContext *context) {
1926 results.add<EraseTrivialCopyOp>(context);
1929 //===----------------------------------------------------------------------===//
1930 // GPU_SubgroupMmaLoadMatrixOp
1931 //===----------------------------------------------------------------------===//
1933 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1934 auto srcType = getSrcMemref().getType();
1935 auto resType = getRes().getType();
1936 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1937 auto operand = resMatrixType.getOperand();
1938 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1940 if (!isLastMemrefDimUnitStride(srcMemrefType))
1942 "expected source memref most minor dim must have unit stride");
1944 if (!operand.equals("AOp") && !operand.equals("BOp") &&
1945 !operand.equals("COp"))
1946 return emitError("only AOp, BOp and COp can be loaded");
1951 //===----------------------------------------------------------------------===//
1952 // GPU_SubgroupMmaStoreMatrixOp
1953 //===----------------------------------------------------------------------===//
1955 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1956 auto srcType = getSrc().getType();
1957 auto dstType = getDstMemref().getType();
1958 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1959 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1961 if (!isLastMemrefDimUnitStride(dstMemrefType))
1963 "expected destination memref most minor dim must have unit stride");
1965 if (!srcMatrixType.getOperand().equals("COp"))
1967 "expected the operand matrix being stored to have 'COp
' operand type");
1972 //===----------------------------------------------------------------------===//
1973 // GPU_SubgroupMmaComputeOp
1974 //===----------------------------------------------------------------------===//
1976 LogicalResult SubgroupMmaComputeOp::verify() {
1977 enum OperandMap { A, B, C };
1978 SmallVector<MMAMatrixType, 3> opTypes;
1979 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1980 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1981 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1983 if (!opTypes[A].getOperand().equals("AOp") ||
1984 !opTypes[B].getOperand().equals("BOp") ||
1985 !opTypes[C].getOperand().equals("COp"))
1986 return emitError("operands must be in the order AOp, BOp, COp");
1988 ArrayRef<int64_t> aShape, bShape, cShape;
1989 aShape = opTypes[A].getShape();
1990 bShape = opTypes[B].getShape();
1991 cShape = opTypes[C].getShape();
1993 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1994 bShape[1] != cShape[1])
1995 return emitError("operand shapes do not satisfy matmul constraints");
2000 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2001 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2002 return memref::foldMemRefCast(*this);
2005 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2006 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2007 return memref::foldMemRefCast(*this);
2010 //===----------------------------------------------------------------------===//
2012 //===----------------------------------------------------------------------===//
2019 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2021 using OpRewritePattern::OpRewritePattern;
2023 LogicalResult matchAndRewrite(WaitOp op,
2024 PatternRewriter &rewriter) const final {
2025 auto predicate = [](Value value) {
2026 auto waitOp = value.getDefiningOp<WaitOp>();
2027 return waitOp && waitOp->getNumOperands() == 0;
2029 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2031 SmallVector<Value> validOperands;
2032 for (Value operand : op->getOperands()) {
2033 if (predicate(operand))
2035 validOperands.push_back(operand);
2037 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2049 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2051 using OpRewritePattern::OpRewritePattern;
2053 LogicalResult matchAndRewrite(WaitOp op,
2054 PatternRewriter &rewriter) const final {
2055 // Erase gpu.wait ops that neither have any async dependencies nor return
2057 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2058 rewriter.eraseOp(op);
2061 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2062 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2063 op.getAsyncToken()) {
2064 rewriter.replaceOp(op, op.getAsyncDependencies());
2067 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2068 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2069 rewriter.eraseOp(op);
2076 } // end anonymous namespace
2078 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2079 MLIRContext *context) {
2080 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2083 //===----------------------------------------------------------------------===//
2085 //===----------------------------------------------------------------------===//
2087 LogicalResult AllocOp::verify() {
2088 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2090 if (static_cast<int64_t>(getDynamicSizes().size()) !=
2091 memRefType.getNumDynamicDims())
2092 return emitOpError("dimension operand count does not equal memref "
2093 "dynamic dimension count");
2095 unsigned numSymbols = 0;
2096 if (!memRefType.getLayout().isIdentity())
2097 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2098 if (getSymbolOperands().size() != numSymbols) {
2100 "symbol operand count does not equal memref symbol count");
2110 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2111 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2113 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2114 PatternRewriter &rewriter) const override {
2115 std::optional<int64_t> index = dimOp.getConstantIndex();
2119 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2120 if (!memrefType || !memrefType.isDynamicDim(index.value()))
2123 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2127 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2128 memrefType.getDynamicDimIndex(index.value()));
2129 rewriter.replaceOp(dimOp, substituteOp);
2136 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2137 MLIRContext *context) {
2138 results.add<SimplifyDimOfAllocOp>(context);
2141 //===----------------------------------------------------------------------===//
2142 // GPU object attribute
2143 //===----------------------------------------------------------------------===//
2145 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2146 Attribute target, CompilationTarget format,
2147 StringAttr object, DictionaryAttr properties) {
2149 return emitError() << "the target attribute cannot be null";
2150 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2152 return emitError() << "the target attribute must implement or promise the "
2153 "`gpu::TargetAttrInterface`";
2157 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2158 StringAttr &object) {
2159 std::optional<CompilationTarget> formatResult;
2160 StringRef enumKeyword;
2161 auto loc = odsParser.getCurrentLocation();
2162 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2163 formatResult = CompilationTarget::Fatbin;
2164 if (!formatResult &&
2166 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2167 odsParser.parseEqual())
2168 return odsParser.emitError(loc, "expected an equal sign");
2170 return odsParser.emitError(loc, "expected keyword for GPU object format");
2171 FailureOr<StringAttr> objectResult =
2172 FieldParser<StringAttr>::parse(odsParser);
2173 if (failed(objectResult))
2174 return odsParser.emitError(odsParser.getCurrentLocation(),
2175 "failed to parse GPU_ObjectAttr parameter "
2176 "'
object' which is to be a `StringAttr`");
2177 format = *formatResult;
2178 object = *objectResult;
2182 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2183 StringAttr object) {
2184 if (format != CompilationTarget::Fatbin)
2185 odsParser << stringifyEnum(format) << " = ";
2186 odsParser << object;
2190 //===----------------------------------------------------------------------===//
2191 // GPU select object attribute
2192 //===----------------------------------------------------------------------===//
2195 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2197 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2199 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2200 if (intAttr.getInt() < 0) {
2201 return emitError() << "the object index must be positive";
2203 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2205 << "the target attribute must be a GPU Target attribute";
2211 //===----------------------------------------------------------------------===//
2212 // DynamicSharedMemoryOp
2213 //===----------------------------------------------------------------------===//
2215 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2216 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2217 return emitOpError() << "must be inside an op with symbol table";
2219 MemRefType memrefType = getResultMemref().getType();
2220 // Check address space
2221 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2222 return emitOpError() << "address space must be "
2223 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2224 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2226 if (memrefType.hasStaticShape()) {
2227 return emitOpError() << "result memref type must be memref<?xi8, "
2228 "#gpu.address_space<workgroup>>";
2233 //===----------------------------------------------------------------------===//
2234 // GPU target options
2235 //===----------------------------------------------------------------------===//
2237 TargetOptions::TargetOptions(
2238 StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2239 StringRef cmdOptions, CompilationTarget compilationTarget,
2240 function_ref<SymbolTable *()> getSymbolTableCallback)
2241 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2242 cmdOptions, compilationTarget, getSymbolTableCallback) {}
2244 TargetOptions::TargetOptions(
2245 TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2246 StringRef cmdOptions, CompilationTarget compilationTarget,
2247 function_ref<SymbolTable *()> getSymbolTableCallback)
2248 : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2249 cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2250 getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2252 TypeID TargetOptions::getTypeID() const { return typeID; }
2254 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2256 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2258 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2260 SymbolTable *TargetOptions::getSymbolTable() const {
2261 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2264 CompilationTarget TargetOptions::getCompilationTarget() const {
2265 return compilationTarget;
2268 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2269 return CompilationTarget::Fatbin;
2272 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2273 TargetOptions::tokenizeCmdOptions() const {
2274 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2275 llvm::StringSaver stringSaver(options.first);
2276 StringRef opts = cmdOptions;
2277 // For a correct tokenization of the command line options `opts` must be
2278 // unquoted, otherwise the tokenization function returns a single string: the
2279 // unquoted `cmdOptions` -which is not the desired behavior.
2280 // Remove any quotes if they are at the beginning and end of the string:
2281 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2282 opts.consume_front("\""), opts.consume_back("\"");
2283 if (!opts.empty() && opts.front() == '\
'' && opts.back() ==
'\'')
2284 opts.consume_front(
"'"), opts.consume_back(
"'");
2286 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2289 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2297 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2298 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2300 #define GET_ATTRDEF_CLASSES
2301 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2303 #define GET_OP_CLASSES
2304 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2306 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op, StringRef attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes, prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type into an MLIR context if it is valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....