33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/ErrorHandling.h"
37 #include "llvm/Support/StringSaver.h"
43 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
49 int64_t GPUBlockMappingAttr::getMappingId()
const {
50 return static_cast<int64_t
>(getBlock());
53 bool GPUBlockMappingAttr::isLinearMapping()
const {
54 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
57 int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
58 return isLinearMapping()
59 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
63 int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
64 return static_cast<int64_t
>(getWarpgroup());
67 bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
68 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
71 int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
72 return isLinearMapping()
73 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
77 int64_t GPUWarpMappingAttr::getMappingId()
const {
78 return static_cast<int64_t
>(getWarp());
81 bool GPUWarpMappingAttr::isLinearMapping()
const {
82 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
85 int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
86 return isLinearMapping()
87 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
91 int64_t GPUThreadMappingAttr::getMappingId()
const {
92 return static_cast<int64_t
>(getThread());
95 bool GPUThreadMappingAttr::isLinearMapping()
const {
96 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
99 int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
100 return isLinearMapping()
101 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
105 int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
106 return static_cast<int64_t
>(getAddressSpace());
109 bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
110 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
113 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
114 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
131 elementType, operand);
145 return elementType.
isF16() || elementType.
isF32() ||
154 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
155 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
157 if (shape.size() != 2)
158 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
162 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
171 bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
174 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
175 return gpuAttr.getValue() == getWorkgroupAddressSpace();
179 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
180 Attribute memorySpace = type.getMemorySpace();
181 return isWorkgroupMemoryAddressSpace(memorySpace);
184 bool GPUDialect::isKernel(
Operation *op) {
185 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
186 return static_cast<bool>(isKernelAttr);
202 void GPUDialect::initialize() {
203 addTypes<AsyncTokenType>();
204 addTypes<MMAMatrixType>();
205 addTypes<SparseDnTensorHandleType>();
206 addTypes<SparseSpMatHandleType>();
207 addTypes<SparseSpGEMMOpHandleType>();
210 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
213 #define GET_ATTRDEF_LIST
214 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
216 addInterfaces<GPUInlinerInterface>();
217 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
224 return "sparse.dntensor_handle";
226 return "sparse.spmat_handle";
228 return "sparse.spgemmop_handle";
230 llvm_unreachable(
"unknown sparse handle kind");
242 if (keyword ==
"async.token")
245 if (keyword ==
"mma_matrix") {
274 shape, elementType, operand);
292 .Case<SparseDnTensorHandleType>([&](
Type) {
295 .Case<SparseSpMatHandleType>(
297 .Case<SparseSpGEMMOpHandleType>([&](
Type) {
303 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
306 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
308 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
313 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
316 " must be a dense i32 array");
317 if (array.size() != 3)
319 " must contain exactly 3 elements");
323 LogicalResult GPUDialect::verifyOperationAttribute(
Operation *op,
325 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
327 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
329 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
330 attr.
getName() != getContainerModuleAttrName())
333 auto module = dyn_cast<ModuleOp>(op);
336 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
337 << ModuleOp::getOperationName() <<
'\'';
339 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
342 if (!launchOp->getParentOp() ||
343 launchOp->getParentOp()->getParentOp() != module)
348 if (!launchOp->getAttrOfType<SymbolRefAttr>(
349 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
353 StringAttr kernelContainerName = launchOp.getKernelModuleName();
354 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
355 if (!kernelContainer)
357 <<
"kernel container '" << kernelContainerName.getValue()
361 if (isa<BinaryOp>(kernelContainer))
364 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
366 return launchOp.emitOpError()
367 <<
"kernel module '" << kernelContainerName.getValue()
371 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
374 << launchOp.getKernel() <<
"' is undefined";
375 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
376 if (!kernelConvertedFunction) {
378 <<
"referenced kernel '" << launchOp.getKernel()
379 <<
"' is not a function";
380 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
385 GPUDialect::getKernelFuncAttrName()))
386 return launchOp.emitOpError(
"kernel function is missing the '")
387 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
392 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
393 if (!kernelGPUFunction)
396 unsigned actualNumArguments = launchOp.getNumKernelOperands();
397 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
398 if (expectedNumArguments != actualNumArguments)
399 return launchOp.emitOpError(
"got ")
400 << actualNumArguments <<
" kernel operands but expected "
401 << expectedNumArguments;
403 auto functionType = kernelGPUFunction.getFunctionType();
404 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
405 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
406 return launchOp.emitOpError(
"type of function argument ")
407 << i <<
" does not match";
414 return walkResult.wasInterrupted() ? failure() : success();
427 return parser.
emitError(loc,
"needs to be named when marked 'async'");
442 if (asyncDependencies.empty())
447 llvm::interleaveComma(asyncDependencies, printer);
476 p <<
' ' << keyword <<
'(';
477 llvm::interleaveComma(
485 gpu::AddressSpace memorySpace) {
486 for (
Value v : attributions) {
487 auto type = llvm::dyn_cast<MemRefType>(v.getType());
489 return op->
emitOpError() <<
"expected memref type in attribution";
494 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
497 if (addressSpace.getValue() != memorySpace)
499 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
500 <<
" in attribution";
511 using Kind = gpu::AllReduceOperation;
512 if (llvm::is_contained(
513 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
515 if (!isa<FloatType>(resType))
519 if (llvm::is_contained({Kind::MINSI,
Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
520 Kind::AND, Kind::OR, Kind::XOR},
522 if (!isa<IntegerType>(resType))
529 LogicalResult gpu::AllReduceOp::verifyRegions() {
530 if (getBody().empty() != getOp().has_value())
531 return emitError(
"expected either an op attribute or a non-empty body");
532 if (!getBody().empty()) {
533 if (getBody().getNumArguments() != 2)
534 return emitError(
"expected two region arguments");
535 for (
auto argument : getBody().getArguments()) {
536 if (argument.getType() !=
getType())
537 return emitError(
"incorrect region argument type");
539 unsigned yieldCount = 0;
540 for (
Block &block : getBody()) {
541 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
542 if (yield.getNumOperands() != 1)
543 return emitError(
"expected one gpu.yield operand");
544 if (yield.getOperand(0).getType() !=
getType())
545 return emitError(
"incorrect gpu.yield type");
550 return emitError(
"expected gpu.yield op in region");
552 gpu::AllReduceOperation opName = *getOp();
554 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
555 <<
"` reduction operation is not compatible with type "
564 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
568 Region &body = launchOp.getBody();
569 assert(!body.
empty() &&
"Invalid region");
586 AllReduceOperationAttr &attr) {
589 std::optional<AllReduceOperation> op =
590 gpu::symbolizeAllReduceOperation(enumStr);
599 AllReduceOperationAttr attr) {
610 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
611 if (vecTy.isScalable())
612 return emitOpError() <<
"is not compatible with scalable vector types";
614 elemType = vecTy.getElementType();
617 gpu::AllReduceOperation opName = getOp();
619 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
620 <<
"` reduction operation is not compatible with type "
624 auto clusterSize = getClusterSize();
626 uint32_t size = *clusterSize;
627 if (!llvm::isPowerOf2_32(size)) {
628 return emitOpError() <<
"cluster size " << size
629 <<
" is not a power of two";
633 uint32_t stride = getClusterStride();
634 if (stride != 1 && !clusterSize) {
635 return emitOpError() <<
"cluster stride can only be specified if cluster "
638 if (!llvm::isPowerOf2_32(stride)) {
639 return emitOpError() <<
"cluster stride " << stride
640 <<
" is not a power of two";
646 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
647 if (getClusterSize() == 1)
664 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
668 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
686 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
695 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
704 result.
addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
705 getBlockSizeY, getBlockSizeZ});
712 if (dynamicSharedMemorySize)
721 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
724 for (
Type argTy : workgroupAttributions)
726 for (
Type argTy : privateAttributions)
730 segmentSizes.front() = asyncDependencies.size();
731 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
732 segmentSizes[7] = clusterSizeX ? 1 : 0;
733 segmentSizes[8] = clusterSizeY ? 1 : 0;
734 segmentSizes[9] = clusterSizeZ ? 1 : 0;
740 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
741 auto args = getBody().getArguments();
746 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
747 auto args = getBody().getArguments();
752 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
753 auto args = getBody().getArguments();
758 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
759 auto args = getBody().getArguments();
760 return KernelDim3{args[9], args[10], args[11]};
763 std::optional<KernelDim3> LaunchOp::getClusterIds() {
764 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
765 if (!hasClusterSize())
767 auto args = getBody().getArguments();
768 return KernelDim3{args[12], args[13], args[14]};
771 std::optional<KernelDim3> LaunchOp::getClusterSize() {
772 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
773 if (!hasClusterSize())
775 auto args = getBody().getArguments();
776 return KernelDim3{args[15], args[16], args[17]};
779 KernelDim3 LaunchOp::getGridSizeOperandValues() {
780 auto operands = getOperands().drop_front(getAsyncDependencies().size());
781 return KernelDim3{operands[0], operands[1], operands[2]};
784 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
785 auto operands = getOperands().drop_front(getAsyncDependencies().size());
786 return KernelDim3{operands[3], operands[4], operands[5]};
789 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
790 auto operands = getOperands().drop_front(getAsyncDependencies().size());
791 if (!hasClusterSize())
793 return KernelDim3{operands[6], operands[7], operands[8]};
797 if (!(hasClusterSize()) &&
798 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
799 return emitOpError() <<
"cluster size must be all present";
803 LogicalResult LaunchOp::verifyRegions() {
807 if (!getBody().empty()) {
808 if (getBody().getNumArguments() <
809 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
810 return emitOpError(
"unexpected number of region arguments");
815 GPUDialect::getWorkgroupAddressSpace())) ||
817 GPUDialect::getPrivateAddressSpace())))
822 for (
Block &block : getBody()) {
825 if (block.back().getNumSuccessors() != 0)
827 if (!isa<gpu::TerminatorOp>(&block.back())) {
830 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
831 "' or a terminator with successors")
832 .attachNote(getLoc())
833 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
837 if (getNumResults() == 0 && getAsyncToken())
838 return emitOpError(
"needs to be named when async keyword is specified");
849 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
850 p << size.
x <<
" = " << operands.
x <<
", ";
851 p << size.
y <<
" = " << operands.
y <<
", ";
852 p << size.
z <<
" = " << operands.
z <<
')';
856 if (getAsyncToken()) {
858 if (!getAsyncDependencies().empty())
859 p <<
" [" << getAsyncDependencies() <<
']';
862 if (hasClusterSize()) {
863 p <<
' ' << getClustersKeyword();
865 getClusterSizeOperandValues().value(),
866 getClusterIds().value());
868 p <<
' ' << getBlocksKeyword();
871 p <<
' ' << getThreadsKeyword();
874 if (getDynamicSharedMemorySize())
875 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
876 << getDynamicSharedMemorySize();
885 LaunchOp::getOperandSegmentSizeAttr(),
886 getNumWorkgroupAttributionsAttrName()});
900 assert(indices.size() == 3 &&
"space for three indices expected");
906 std::move(args.begin(), args.end(), indices.begin());
908 for (
int i = 0; i < 3; ++i) {
930 sizes(LaunchOp::kNumConfigOperands);
937 LaunchOp::kNumConfigRegionAttributes);
948 result.
types.push_back(asyncTokenType);
950 bool hasCluster =
false;
955 regionArgs.resize(18);
964 regionArgsRef.slice(15, 3),
965 regionArgsRef.slice(12, 3)))
973 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
975 regionArgsRef.slice(6, 3),
976 regionArgsRef.slice(0, 3)) ||
977 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
979 regionArgsRef.slice(9, 3),
980 regionArgsRef.slice(3, 3)) ||
986 bool hasDynamicSharedMemorySize =
false;
988 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
989 hasDynamicSharedMemorySize =
true;
1005 LaunchOp::kNumConfigRegionAttributes + 6, index);
1008 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1010 arg.
ssaName = std::get<0>(ssaValueAndType);
1011 arg.
type = std::get<1>(ssaValueAndType);
1012 regionArguments.push_back(arg);
1023 unsigned numWorkgroupAttrs = regionArguments.size() -
1024 LaunchOp::kNumConfigRegionAttributes -
1025 (hasCluster ? 6 : 0);
1026 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1043 segmentSizes.front() = asyncDependencies.size();
1046 segmentSizes[7] = 0;
1047 segmentSizes[8] = 0;
1048 segmentSizes[9] = 0;
1050 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1051 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1065 bool simplified =
false;
1066 auto constPropIdUses = [&](
Value id,
Value size) {
1070 if (
id.getUses().empty())
1077 rewriter.
create<arith::ConstantIndexOp>(op.
getLoc(), 0);
1082 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1083 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1084 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1085 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1086 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1087 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1089 return success(simplified);
1101 auto attrName = getNumWorkgroupAttributionsAttrName();
1102 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1103 (*this)->setAttr(attrName,
1105 return getBody().insertArgument(
1106 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1114 return getBody().addArgument(type, loc);
1122 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1126 std::optional<KernelDim3> clusterSize) {
1127 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1128 "expected a symbol reference with a single nested reference");
1136 if (clusterSize.has_value())
1137 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1138 if (dynamicSharedMemorySize)
1143 prop.kernel = kernelSymbol;
1144 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1146 for (
auto &sz : prop.operandSegmentSizes)
1148 prop.operandSegmentSizes[0] = asyncDependencies.size();
1149 if (!clusterSize.has_value()) {
1150 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1151 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1152 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1154 prop.operandSegmentSizes[segmentSizesLen - 3] =
1155 dynamicSharedMemorySize ? 1 : 0;
1156 prop.operandSegmentSizes[segmentSizesLen - 2] =
1157 static_cast<int32_t
>(kernelOperands.size());
1158 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1166 std::optional<KernelDim3> clusterSize) {
1167 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1170 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1171 build(builder, result, kernelSymbol, gridSize,
getBlockSize,
1172 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1173 asyncDependencies, clusterSize);
1180 std::optional<KernelDim3> clusterSize) {
1184 if (clusterSize.has_value())
1185 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1186 if (dynamicSharedMemorySize)
1192 prop.kernel = kernel;
1193 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1195 for (
auto &sz : prop.operandSegmentSizes)
1197 prop.operandSegmentSizes[0] = 0;
1198 if (!clusterSize.has_value()) {
1199 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1200 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1201 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1203 prop.operandSegmentSizes[segmentSizesLen - 3] =
1204 dynamicSharedMemorySize ? 1 : 0;
1205 prop.operandSegmentSizes[segmentSizesLen - 2] =
1206 static_cast<int32_t
>(kernelOperands.size());
1207 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1210 StringAttr LaunchFuncOp::getKernelModuleName() {
1214 StringAttr LaunchFuncOp::getKernelName() {
1218 unsigned LaunchFuncOp::getNumKernelOperands() {
1219 return getKernelOperands().size();
1222 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1223 return getKernelOperands()[i];
1226 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1227 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1228 return KernelDim3{operands[0], operands[1], operands[2]};
1231 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1232 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1233 return KernelDim3{operands[3], operands[4], operands[5]};
1236 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1237 assert(hasClusterSize() &&
1238 "cluster size is not set, check hasClusterSize() first");
1239 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1240 return KernelDim3{operands[6], operands[7], operands[8]};
1244 auto module = (*this)->getParentOfType<ModuleOp>();
1246 return emitOpError(
"expected to belong to a module");
1248 if (!module->getAttrOfType<UnitAttr>(
1249 GPUDialect::getContainerModuleAttrName()))
1250 return emitOpError(
"expected the closest surrounding module to have the '" +
1251 GPUDialect::getContainerModuleAttrName() +
1254 if (hasClusterSize()) {
1255 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1257 return emitOpError()
1258 <<
"expects types of the cluster dimensions must be the same";
1266 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1267 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1274 if (clusterValue.has_value()) {
1275 clusterXTy = clusterYTy = clusterZTy = dimTy;
1282 Type clusterYTy,
Type clusterZTy) {
1284 printer <<
": " << dimTy;
1294 auto parseElement = [&]() -> ParseResult {
1295 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1300 parseElement,
" in argument list");
1305 if (operands.empty())
1308 llvm::interleaveComma(llvm::zip(operands, types), printer,
1309 [&](
const auto &pair) {
1322 int32_t offset, int32_t width, ShuffleMode mode) {
1323 build(builder, result, value,
1338 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1340 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1351 results.
add(eraseRedundantGpuBarrierOps);
1361 auto attrName = getNumWorkgroupAttributionsAttrName();
1362 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1363 (*this)->setAttr(attrName,
1365 return getBody().insertArgument(
1366 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1374 return getBody().addArgument(type, loc);
1378 StringRef name, FunctionType type,
1388 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
1395 for (
Type argTy : type.getInputs())
1397 for (
Type argTy : workgroupAttributions)
1399 for (
Type argTy : privateAttributions)
1418 size_t existingArgs = args.size();
1419 ParseResult result =
1425 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1430 attributionAttrs =
nullptr;
1436 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1437 if (!argument.attrs)
1440 attributionAttrsVec.push_back(argument.attrs);
1442 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1458 StringAttr nameAttr;
1465 parser,
false, entryArgs, isVariadic, resultTypes,
1469 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1470 return parser.
emitError(signatureLocation)
1471 <<
"gpu.func requires named arguments";
1478 for (
auto &arg : entryArgs)
1479 argTypes.push_back(arg.
type);
1485 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
1486 getResAttrsAttrName(result.
name));
1491 entryArgs, workgroupAttributionAttrs)))
1496 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1497 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1499 if (workgroupAttributionAttrs)
1500 result.
addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.
name),
1501 workgroupAttributionAttrs);
1506 entryArgs, privateAttributionAttrs)))
1508 if (privateAttributionAttrs)
1510 privateAttributionAttrs);
1514 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
1529 ArrayAttr attributes) {
1533 p <<
' ' << keyword <<
'(';
1534 llvm::interleaveComma(
1537 p << v <<
" : " << v.
getType();
1539 size_t attributionIndex = pair.index();
1540 DictionaryAttr attrs;
1541 if (attributes && attributionIndex < attributes.size())
1542 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1553 FunctionType type = getFunctionType();
1559 getWorkgroupAttribAttrs().value_or(
nullptr));
1561 getPrivateAttribAttrs().value_or(
nullptr));
1563 p <<
' ' << getKernelKeyword();
1567 {getNumWorkgroupAttributionsAttrName(),
1568 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1569 getArgAttrsAttrName(), getResAttrsAttrName(),
1570 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1576 StringAttr attrName) {
1577 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->
getAttr(attrName));
1578 if (!allAttrs || index >= allAttrs.size())
1579 return DictionaryAttr();
1580 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1583 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1587 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1592 DictionaryAttr value, StringAttr attrName) {
1594 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->
getAttr(attrName));
1597 elements.append(allAttrs.begin(), allAttrs.end());
1598 while (elements.size() <= index)
1603 elements[index] = value;
1605 op->
setAttr(attrName, newValue);
1608 void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1609 DictionaryAttr value) {
1613 void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1614 DictionaryAttr value) {
1619 StringAttr name, StringAttr attrsName) {
1623 return dict.get(name);
1626 Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1628 assert(index < getNumWorkgroupAttributions() &&
1629 "index must map to a workgroup attribution");
1631 getWorkgroupAttribAttrsAttrName());
1634 Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1636 assert(index < getNumPrivateAttributions() &&
1637 "index must map to a private attribution");
1639 getPrivateAttribAttrsAttrName());
1643 Attribute value, StringAttr attrsName) {
1648 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1651 bool mustSort =
true;
1652 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1653 if (elems[i].getName() == name) {
1656 std::swap(elems[i], elems[elems.size() - 1]);
1668 elems.emplace_back(name, value);
1671 DictionaryAttr::sortInPlace(elems);
1673 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1677 void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1679 assert(index < getNumWorkgroupAttributions() &&
1680 "index must map to a workgroup attribution");
1682 getWorkgroupAttribAttrsAttrName());
1685 void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1687 assert(index < getNumPrivateAttributions() &&
1688 "index must map to a private attribution");
1690 getPrivateAttribAttrsAttrName());
1693 LogicalResult GPUFuncOp::verifyType() {
1694 if (isKernel() && getFunctionType().getNumResults() != 0)
1695 return emitOpError() <<
"expected void return type for kernel function";
1701 LogicalResult GPUFuncOp::verifyBody() {
1703 return emitOpError() <<
"expected body with at least one block";
1704 unsigned numFuncArguments = getNumArguments();
1705 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1706 unsigned numBlockArguments = front().getNumArguments();
1707 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1708 return emitOpError() <<
"expected at least "
1709 << numFuncArguments + numWorkgroupAttributions
1710 <<
" arguments to body region";
1713 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1714 Type blockArgType = front().getArgument(i).getType();
1715 if (funcArgTypes[i] != blockArgType)
1716 return emitOpError() <<
"expected body region argument #" << i
1717 <<
" to be of type " << funcArgTypes[i] <<
", got "
1722 GPUDialect::getWorkgroupAddressSpace())) ||
1724 GPUDialect::getPrivateAddressSpace())))
1735 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1737 FunctionType funType =
function.getFunctionType();
1739 if (funType.getNumResults() != getOperands().size())
1740 return emitOpError()
1741 .append(
"expected ", funType.getNumResults(),
" result operands")
1742 .attachNote(
function.getLoc())
1743 .append(
"return type declared here");
1746 llvm::zip(
function.getFunctionType().getResults(), getOperands()))) {
1747 auto [type, operand] = pair.value();
1748 if (type != operand.getType())
1749 return emitOpError() <<
"unexpected type `" << operand.getType()
1750 <<
"' for operand #" << pair.index();
1760 StringRef name, ArrayAttr targets,
1765 props.targets = targets;
1767 props.offloadingHandler = offloadingHandler;
1773 build(builder, result, name,
1774 targets.empty() ? ArrayAttr() : builder.
getArrayAttr(targets),
1778 bool GPUModuleOp::hasTarget(
Attribute target) {
1779 if (ArrayAttr targets = getTargetsAttr())
1780 return llvm::count(targets.getValue(), target);
1785 ArrayAttr &targetsAttr = getProperties().targets;
1794 Attribute offloadingHandler, ArrayAttr objects) {
1798 properties.objects = objects;
1799 if (offloadingHandler)
1800 properties.offloadingHandler = offloadingHandler;
1802 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1807 build(builder, result, name, offloadingHandler,
1808 objects.empty() ? ArrayAttr() : builder.
getArrayAttr(objects));
1819 if (!offloadingHandler)
1827 printer << '<' << offloadingHandler << '>
';
1830 //===----------------------------------------------------------------------===//
1832 //===----------------------------------------------------------------------===//
1834 LogicalResult MemcpyOp::verify() {
1835 auto srcType = getSrc().getType();
1836 auto dstType = getDst().getType();
1838 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1839 return emitOpError("arguments have incompatible element type");
1841 if (failed(verifyCompatibleShape(srcType, dstType)))
1842 return emitOpError("arguments have incompatible shape");
1851 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1852 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1854 LogicalResult matchAndRewrite(MemcpyOp op,
1855 PatternRewriter &rewriter) const override {
1856 Value dest = op.getDst();
1857 Operation *destDefOp = dest.getDefiningOp();
1858 // `dest` must be defined by an op having Allocate memory effect in order to
1859 // perform the folding.
1861 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1863 // We can erase `op` iff `dest` has no other use apart from its
1864 // use by `op` and dealloc ops.
1865 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1866 return user != op &&
1867 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1870 // We can perform the folding if and only if op has a single async
1871 // dependency and produces an async token as result, or if it does not have
1872 // any async dependency and does not produce any async token result.
1873 if (op.getAsyncDependencies().size() > 1 ||
1874 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1875 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1877 rewriter.replaceOp(op, op.getAsyncDependencies());
1882 } // end anonymous namespace
1884 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1885 MLIRContext *context) {
1886 results.add<EraseTrivialCopyOp>(context);
1889 //===----------------------------------------------------------------------===//
1890 // GPU_SubgroupMmaLoadMatrixOp
1891 //===----------------------------------------------------------------------===//
1893 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1894 auto srcType = getSrcMemref().getType();
1895 auto resType = getRes().getType();
1896 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1897 auto operand = resMatrixType.getOperand();
1898 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1900 if (!isLastMemrefDimUnitStride(srcMemrefType))
1902 "expected source memref most minor dim must have unit stride");
1904 if (operand != "AOp" && operand != "BOp" && operand != "COp")
1905 return emitError("only AOp, BOp and COp can be loaded");
1910 //===----------------------------------------------------------------------===//
1911 // GPU_SubgroupMmaStoreMatrixOp
1912 //===----------------------------------------------------------------------===//
1914 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1915 auto srcType = getSrc().getType();
1916 auto dstType = getDstMemref().getType();
1917 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1918 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1920 if (!isLastMemrefDimUnitStride(dstMemrefType))
1922 "expected destination memref most minor dim must have unit stride");
1924 if (srcMatrixType.getOperand() != "COp")
1926 "expected the operand matrix being stored to have 'COp
' operand type");
1931 //===----------------------------------------------------------------------===//
1932 // GPU_SubgroupMmaComputeOp
1933 //===----------------------------------------------------------------------===//
1935 LogicalResult SubgroupMmaComputeOp::verify() {
1936 enum OperandMap { A, B, C };
1937 SmallVector<MMAMatrixType, 3> opTypes;
1938 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1939 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1940 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1942 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1943 opTypes[C].getOperand() != "COp")
1944 return emitError("operands must be in the order AOp, BOp, COp");
1946 ArrayRef<int64_t> aShape, bShape, cShape;
1947 aShape = opTypes[A].getShape();
1948 bShape = opTypes[B].getShape();
1949 cShape = opTypes[C].getShape();
1951 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1952 bShape[1] != cShape[1])
1953 return emitError("operand shapes do not satisfy matmul constraints");
1958 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1959 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1960 return memref::foldMemRefCast(*this);
1963 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1964 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1965 return memref::foldMemRefCast(*this);
1968 //===----------------------------------------------------------------------===//
1970 //===----------------------------------------------------------------------===//
1977 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1979 using OpRewritePattern::OpRewritePattern;
1981 LogicalResult matchAndRewrite(WaitOp op,
1982 PatternRewriter &rewriter) const final {
1983 auto predicate = [](Value value) {
1984 auto waitOp = value.getDefiningOp<WaitOp>();
1985 return waitOp && waitOp->getNumOperands() == 0;
1987 if (llvm::none_of(op.getAsyncDependencies(), predicate))
1989 SmallVector<Value> validOperands;
1990 for (Value operand : op->getOperands()) {
1991 if (predicate(operand))
1993 validOperands.push_back(operand);
1995 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2007 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2009 using OpRewritePattern::OpRewritePattern;
2011 LogicalResult matchAndRewrite(WaitOp op,
2012 PatternRewriter &rewriter) const final {
2013 // Erase gpu.wait ops that neither have any async dependencies nor return
2015 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2016 rewriter.eraseOp(op);
2019 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2020 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2021 op.getAsyncToken()) {
2022 rewriter.replaceOp(op, op.getAsyncDependencies());
2025 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2026 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2027 rewriter.eraseOp(op);
2034 } // end anonymous namespace
2036 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2037 MLIRContext *context) {
2038 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2041 //===----------------------------------------------------------------------===//
2043 //===----------------------------------------------------------------------===//
2045 LogicalResult AllocOp::verify() {
2046 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2048 if (static_cast<int64_t>(getDynamicSizes().size()) !=
2049 memRefType.getNumDynamicDims())
2050 return emitOpError("dimension operand count does not equal memref "
2051 "dynamic dimension count");
2053 unsigned numSymbols = 0;
2054 if (!memRefType.getLayout().isIdentity())
2055 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2056 if (getSymbolOperands().size() != numSymbols) {
2058 "symbol operand count does not equal memref symbol count");
2068 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2069 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2071 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2072 PatternRewriter &rewriter) const override {
2073 std::optional<int64_t> index = dimOp.getConstantIndex();
2077 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2078 if (!memrefType || !memrefType.isDynamicDim(index.value()))
2081 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2085 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2086 memrefType.getDynamicDimIndex(index.value()));
2087 rewriter.replaceOp(dimOp, substituteOp);
2094 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2095 MLIRContext *context) {
2096 results.add<SimplifyDimOfAllocOp>(context);
2099 //===----------------------------------------------------------------------===//
2100 // GPU object attribute
2101 //===----------------------------------------------------------------------===//
2103 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2104 Attribute target, CompilationTarget format,
2105 StringAttr object, DictionaryAttr properties,
2106 KernelTableAttr kernels) {
2108 return emitError() << "the target attribute cannot be null";
2109 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2111 return emitError() << "the target attribute must implement or promise the "
2112 "`gpu::TargetAttrInterface`";
2116 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2117 StringAttr &object) {
2118 std::optional<CompilationTarget> formatResult;
2119 StringRef enumKeyword;
2120 auto loc = odsParser.getCurrentLocation();
2121 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2122 formatResult = CompilationTarget::Fatbin;
2123 if (!formatResult &&
2125 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2126 odsParser.parseEqual())
2127 return odsParser.emitError(loc, "expected an equal sign");
2129 return odsParser.emitError(loc, "expected keyword for GPU object format");
2130 FailureOr<StringAttr> objectResult =
2131 FieldParser<StringAttr>::parse(odsParser);
2132 if (failed(objectResult))
2133 return odsParser.emitError(odsParser.getCurrentLocation(),
2134 "failed to parse GPU_ObjectAttr parameter "
2135 "'
object' which is to be a `StringAttr`");
2136 format = *formatResult;
2137 object = *objectResult;
2141 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2142 StringAttr object) {
2143 if (format != CompilationTarget::Fatbin)
2144 odsParser << stringifyEnum(format) << " = ";
2145 odsParser << object;
2149 //===----------------------------------------------------------------------===//
2150 // GPU select object attribute
2151 //===----------------------------------------------------------------------===//
2154 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2156 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2158 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2159 if (intAttr.getInt() < 0) {
2160 return emitError() << "the object index must be positive";
2162 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2164 << "the target attribute must be a GPU Target attribute";
2170 //===----------------------------------------------------------------------===//
2171 // DynamicSharedMemoryOp
2172 //===----------------------------------------------------------------------===//
2174 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2175 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2176 return emitOpError() << "must be inside an op with symbol table";
2178 MemRefType memrefType = getResultMemref().getType();
2179 // Check address space
2180 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2181 return emitOpError() << "address space must be "
2182 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2183 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2185 if (memrefType.hasStaticShape()) {
2186 return emitOpError() << "result memref type must be memref<?xi8, "
2187 "#gpu.address_space<workgroup>>";
2192 //===----------------------------------------------------------------------===//
2193 // GPU KernelMetadataAttr
2194 //===----------------------------------------------------------------------===//
2196 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2197 DictionaryAttr metadata) {
2198 assert(kernel && "invalid kernel");
2199 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2200 kernel.getAllArgAttrs(), metadata);
2204 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2205 FunctionOpInterface kernel,
2206 DictionaryAttr metadata) {
2207 assert(kernel && "invalid kernel");
2208 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2209 kernel.getAllArgAttrs(), metadata);
2213 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2216 NamedAttrList attrList;
2217 if (DictionaryAttr dict = getMetadata())
2218 attrList.append(dict);
2219 attrList.append(attrs);
2220 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2221 attrList.getDictionary(getContext()));
2225 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2226 StringAttr name, Type functionType,
2227 ArrayAttr argAttrs, DictionaryAttr metadata) {
2229 return emitError() << "the kernel name can't be empty
";
2231 if (llvm::any_of(argAttrs, [](Attribute attr) {
2232 return !llvm::isa<DictionaryAttr>(attr);
2235 << "all attributes in the array must be a dictionary attribute
";
2240 //===----------------------------------------------------------------------===//
2241 // GPU KernelTableAttr
2242 //===----------------------------------------------------------------------===//
2244 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2245 ArrayRef<KernelMetadataAttr> kernels,
2247 // Note that `is_sorted` is always only invoked once even with assertions ON.
2248 assert((!isSorted || llvm::is_sorted(kernels)) &&
2249 "expected a sorted kernel array
");
2250 // Immediately return the attribute if the array is sorted.
2251 if (isSorted || llvm::is_sorted(kernels))
2252 return Base::get(context, kernels);
2254 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2255 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2256 return Base::get(context, kernelsTmp);
2259 KernelTableAttr KernelTableAttr::getChecked(
2260 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2261 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2262 // Note that `is_sorted` is always only invoked once even with assertions ON.
2263 assert((!isSorted || llvm::is_sorted(kernels)) &&
2264 "expected a sorted kernel array
");
2265 // Immediately return the attribute if the array is sorted.
2266 if (isSorted || llvm::is_sorted(kernels))
2267 return Base::getChecked(emitError, context, kernels);
2269 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2270 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2271 return Base::getChecked(emitError, context, kernelsTmp);
2275 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2276 ArrayRef<KernelMetadataAttr> kernels) {
2277 if (kernels.size() < 2)
2279 // Check that the kernels are uniquely named.
2280 if (std::adjacent_find(kernels.begin(), kernels.end(),
2281 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2282 return l.getName() == r.getName();
2283 }) != kernels.end()) {
2284 return emitError() << "expected all kernels to be uniquely named
";
2289 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2290 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2291 return found ? *iterator : KernelMetadataAttr();
2294 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2295 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2296 return found ? *iterator : KernelMetadataAttr();
2299 //===----------------------------------------------------------------------===//
2300 // GPU target options
2301 //===----------------------------------------------------------------------===//
2303 TargetOptions::TargetOptions(
2304 StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2305 StringRef cmdOptions, CompilationTarget compilationTarget,
2306 function_ref<SymbolTable *()> getSymbolTableCallback)
2307 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2308 cmdOptions, compilationTarget, getSymbolTableCallback) {}
2310 TargetOptions::TargetOptions(
2311 TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2312 StringRef cmdOptions, CompilationTarget compilationTarget,
2313 function_ref<SymbolTable *()> getSymbolTableCallback)
2314 : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2315 cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2316 getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2318 TypeID TargetOptions::getTypeID() const { return typeID; }
2320 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2322 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2324 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2326 SymbolTable *TargetOptions::getSymbolTable() const {
2327 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2330 CompilationTarget TargetOptions::getCompilationTarget() const {
2331 return compilationTarget;
2334 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2335 return CompilationTarget::Fatbin;
2338 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2339 TargetOptions::tokenizeCmdOptions() const {
2340 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2341 llvm::StringSaver stringSaver(options.first);
2342 StringRef opts = cmdOptions;
2343 // For a correct tokenization of the command line options `opts` must be
2344 // unquoted, otherwise the tokenization function returns a single string: the
2345 // unquoted `cmdOptions` -which is not the desired behavior.
2346 // Remove any quotes if they are at the beginning and end of the string:
2347 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2348 opts.consume_front("\
""), opts.consume_back(
"\"");
2349 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2350 opts.consume_front(
"'"), opts.consume_back(
"'");
2352 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2355 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2363 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2364 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2366 #define GET_ATTRDEF_CLASSES
2367 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2369 #define GET_OP_CLASSES
2370 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2372 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....