33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/ErrorHandling.h"
37 #include "llvm/Support/StringSaver.h"
43 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
49 int64_t GPUBlockMappingAttr::getMappingId()
const {
50 return static_cast<int64_t
>(getBlock());
53 bool GPUBlockMappingAttr::isLinearMapping()
const {
54 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
57 int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
58 return isLinearMapping()
59 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
63 int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
64 return static_cast<int64_t
>(getWarpgroup());
67 bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
68 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
71 int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
72 return isLinearMapping()
73 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
77 int64_t GPUWarpMappingAttr::getMappingId()
const {
78 return static_cast<int64_t
>(getWarp());
81 bool GPUWarpMappingAttr::isLinearMapping()
const {
82 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
85 int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
86 return isLinearMapping()
87 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
91 int64_t GPUThreadMappingAttr::getMappingId()
const {
92 return static_cast<int64_t
>(getThread());
95 bool GPUThreadMappingAttr::isLinearMapping()
const {
96 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
99 int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
100 return isLinearMapping()
101 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
105 int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
106 return static_cast<int64_t
>(getAddressSpace());
109 bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
110 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
113 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
114 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
131 elementType, operand);
145 return elementType.
isF16() || elementType.
isF32() ||
154 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
155 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
157 if (shape.size() != 2)
158 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
162 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
171 bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
174 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
175 return gpuAttr.getValue() == getWorkgroupAddressSpace();
179 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
180 Attribute memorySpace = type.getMemorySpace();
181 return isWorkgroupMemoryAddressSpace(memorySpace);
184 bool GPUDialect::isKernel(
Operation *op) {
185 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
186 return static_cast<bool>(isKernelAttr);
202 void GPUDialect::initialize() {
203 addTypes<AsyncTokenType>();
204 addTypes<MMAMatrixType>();
205 addTypes<SparseDnTensorHandleType>();
206 addTypes<SparseSpMatHandleType>();
207 addTypes<SparseSpGEMMOpHandleType>();
210 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
213 #define GET_ATTRDEF_LIST
214 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
216 addInterfaces<GPUInlinerInterface>();
217 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
224 return "sparse.dntensor_handle";
226 return "sparse.spmat_handle";
228 return "sparse.spgemmop_handle";
230 llvm_unreachable(
"unknown sparse handle kind");
242 if (keyword ==
"async.token")
245 if (keyword ==
"mma_matrix") {
274 shape, elementType, operand);
292 .Case<SparseDnTensorHandleType>([&](
Type) {
295 .Case<SparseSpMatHandleType>(
297 .Case<SparseSpGEMMOpHandleType>([&](
Type) {
303 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
306 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
308 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
313 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
316 " must be a dense i32 array");
317 if (array.size() != 3)
319 " must contain exactly 3 elements");
323 LogicalResult GPUDialect::verifyOperationAttribute(
Operation *op,
325 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
327 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
329 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
330 attr.
getName() != getContainerModuleAttrName())
333 auto module = dyn_cast<ModuleOp>(op);
336 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
337 << ModuleOp::getOperationName() <<
'\'';
339 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
342 if (!launchOp->getParentOp() ||
343 launchOp->getParentOp()->getParentOp() != module)
348 if (!launchOp->getAttrOfType<SymbolRefAttr>(
349 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
353 StringAttr kernelContainerName = launchOp.getKernelModuleName();
354 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
355 if (!kernelContainer)
357 <<
"kernel container '" << kernelContainerName.getValue()
361 if (isa<BinaryOp>(kernelContainer))
364 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
366 return launchOp.emitOpError()
367 <<
"kernel module '" << kernelContainerName.getValue()
371 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
374 << launchOp.getKernel() <<
"' is undefined";
375 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
376 if (!kernelConvertedFunction) {
378 <<
"referenced kernel '" << launchOp.getKernel()
379 <<
"' is not a function";
380 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
385 GPUDialect::getKernelFuncAttrName()))
386 return launchOp.emitOpError(
"kernel function is missing the '")
387 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
392 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
393 if (!kernelGPUFunction)
396 unsigned actualNumArguments = launchOp.getNumKernelOperands();
397 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
398 if (expectedNumArguments != actualNumArguments)
399 return launchOp.emitOpError(
"got ")
400 << actualNumArguments <<
" kernel operands but expected "
401 << expectedNumArguments;
403 auto functionType = kernelGPUFunction.getFunctionType();
404 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
405 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
406 return launchOp.emitOpError(
"type of function argument ")
407 << i <<
" does not match";
414 return walkResult.wasInterrupted() ? failure() : success();
427 return parser.
emitError(loc,
"needs to be named when marked 'async'");
442 if (asyncDependencies.empty())
447 llvm::interleaveComma(asyncDependencies, printer);
476 p <<
' ' << keyword <<
'(';
477 llvm::interleaveComma(
485 gpu::AddressSpace memorySpace) {
486 for (
Value v : attributions) {
487 auto type = llvm::dyn_cast<MemRefType>(v.getType());
489 return op->
emitOpError() <<
"expected memref type in attribution";
494 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
497 if (addressSpace.getValue() != memorySpace)
499 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
500 <<
" in attribution";
511 using Kind = gpu::AllReduceOperation;
512 if (llvm::is_contained(
513 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
515 if (!isa<FloatType>(resType))
519 if (llvm::is_contained({Kind::MINSI,
Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
520 Kind::AND, Kind::OR, Kind::XOR},
522 if (!isa<IntegerType>(resType))
529 LogicalResult gpu::AllReduceOp::verifyRegions() {
530 if (getBody().empty() != getOp().has_value())
531 return emitError(
"expected either an op attribute or a non-empty body");
532 if (!getBody().empty()) {
533 if (getBody().getNumArguments() != 2)
534 return emitError(
"expected two region arguments");
535 for (
auto argument : getBody().getArguments()) {
536 if (argument.getType() !=
getType())
537 return emitError(
"incorrect region argument type");
539 unsigned yieldCount = 0;
540 for (
Block &block : getBody()) {
541 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
542 if (yield.getNumOperands() != 1)
543 return emitError(
"expected one gpu.yield operand");
544 if (yield.getOperand(0).getType() !=
getType())
545 return emitError(
"incorrect gpu.yield type");
550 return emitError(
"expected gpu.yield op in region");
552 gpu::AllReduceOperation opName = *getOp();
554 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
555 <<
"` reduction operation is not compatible with type "
564 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
568 Region &body = launchOp.getBody();
569 assert(!body.
empty() &&
"Invalid region");
586 AllReduceOperationAttr &attr) {
589 std::optional<AllReduceOperation> op =
590 gpu::symbolizeAllReduceOperation(enumStr);
599 AllReduceOperationAttr attr) {
610 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
611 if (vecTy.isScalable())
612 return emitOpError() <<
"is not compatible with scalable vector types";
614 elemType = vecTy.getElementType();
617 gpu::AllReduceOperation opName = getOp();
619 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
620 <<
"` reduction operation is not compatible with type "
626 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
641 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
645 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
663 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
672 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
681 result.
addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
682 getBlockSizeY, getBlockSizeZ});
689 if (dynamicSharedMemorySize)
698 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
701 for (
Type argTy : workgroupAttributions)
703 for (
Type argTy : privateAttributions)
707 segmentSizes.front() = asyncDependencies.size();
708 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
709 segmentSizes[7] = clusterSizeX ? 1 : 0;
710 segmentSizes[8] = clusterSizeY ? 1 : 0;
711 segmentSizes[9] = clusterSizeZ ? 1 : 0;
717 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
718 auto args = getBody().getArguments();
723 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
724 auto args = getBody().getArguments();
729 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
730 auto args = getBody().getArguments();
735 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
736 auto args = getBody().getArguments();
737 return KernelDim3{args[9], args[10], args[11]};
740 std::optional<KernelDim3> LaunchOp::getClusterIds() {
741 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
742 if (!hasClusterSize())
744 auto args = getBody().getArguments();
745 return KernelDim3{args[12], args[13], args[14]};
748 std::optional<KernelDim3> LaunchOp::getClusterSize() {
749 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
750 if (!hasClusterSize())
752 auto args = getBody().getArguments();
753 return KernelDim3{args[15], args[16], args[17]};
756 KernelDim3 LaunchOp::getGridSizeOperandValues() {
757 auto operands = getOperands().drop_front(getAsyncDependencies().size());
758 return KernelDim3{operands[0], operands[1], operands[2]};
761 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
762 auto operands = getOperands().drop_front(getAsyncDependencies().size());
763 return KernelDim3{operands[3], operands[4], operands[5]};
766 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
767 auto operands = getOperands().drop_front(getAsyncDependencies().size());
768 if (!hasClusterSize())
770 return KernelDim3{operands[6], operands[7], operands[8]};
774 if (!(hasClusterSize()) &&
775 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
776 return emitOpError() <<
"cluster size must be all present";
780 LogicalResult LaunchOp::verifyRegions() {
784 if (!getBody().empty()) {
785 if (getBody().getNumArguments() <
786 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
787 return emitOpError(
"unexpected number of region arguments");
792 GPUDialect::getWorkgroupAddressSpace())) ||
794 GPUDialect::getPrivateAddressSpace())))
799 for (
Block &block : getBody()) {
802 if (block.back().getNumSuccessors() != 0)
804 if (!isa<gpu::TerminatorOp>(&block.back())) {
807 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
808 "' or a terminator with successors")
809 .attachNote(getLoc())
810 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
814 if (getNumResults() == 0 && getAsyncToken())
815 return emitOpError(
"needs to be named when async keyword is specified");
826 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
827 p << size.
x <<
" = " << operands.
x <<
", ";
828 p << size.
y <<
" = " << operands.
y <<
", ";
829 p << size.
z <<
" = " << operands.
z <<
')';
833 if (getAsyncToken()) {
835 if (!getAsyncDependencies().empty())
836 p <<
" [" << getAsyncDependencies() <<
']';
839 if (hasClusterSize()) {
840 p <<
' ' << getClustersKeyword();
842 getClusterSizeOperandValues().value(),
843 getClusterIds().value());
845 p <<
' ' << getBlocksKeyword();
848 p <<
' ' << getThreadsKeyword();
851 if (getDynamicSharedMemorySize())
852 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
853 << getDynamicSharedMemorySize();
862 LaunchOp::getOperandSegmentSizeAttr(),
863 getNumWorkgroupAttributionsAttrName()});
877 assert(indices.size() == 3 &&
"space for three indices expected");
883 std::move(args.begin(), args.end(), indices.begin());
885 for (
int i = 0; i < 3; ++i) {
907 sizes(LaunchOp::kNumConfigOperands);
914 LaunchOp::kNumConfigRegionAttributes);
925 result.
types.push_back(asyncTokenType);
927 bool hasCluster =
false;
932 regionArgs.resize(18);
941 regionArgsRef.slice(15, 3),
942 regionArgsRef.slice(12, 3)))
950 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
952 regionArgsRef.slice(6, 3),
953 regionArgsRef.slice(0, 3)) ||
954 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
956 regionArgsRef.slice(9, 3),
957 regionArgsRef.slice(3, 3)) ||
963 bool hasDynamicSharedMemorySize =
false;
965 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
966 hasDynamicSharedMemorySize =
true;
982 LaunchOp::kNumConfigRegionAttributes + 6, index);
985 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
987 arg.
ssaName = std::get<0>(ssaValueAndType);
988 arg.
type = std::get<1>(ssaValueAndType);
989 regionArguments.push_back(arg);
1000 unsigned numWorkgroupAttrs = regionArguments.size() -
1001 LaunchOp::kNumConfigRegionAttributes -
1002 (hasCluster ? 6 : 0);
1003 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1020 segmentSizes.front() = asyncDependencies.size();
1023 segmentSizes[7] = 0;
1024 segmentSizes[8] = 0;
1025 segmentSizes[9] = 0;
1027 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1028 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1042 bool simplified =
false;
1043 auto constPropIdUses = [&](
Value id,
Value size) {
1047 if (
id.getUses().empty())
1054 rewriter.
create<arith::ConstantIndexOp>(op.
getLoc(), 0);
1059 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1060 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1061 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1062 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1063 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1064 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1066 return success(simplified);
1078 auto attrName = getNumWorkgroupAttributionsAttrName();
1079 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1080 (*this)->setAttr(attrName,
1082 return getBody().insertArgument(
1083 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1091 return getBody().addArgument(type, loc);
1099 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1103 std::optional<KernelDim3> clusterSize) {
1104 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1105 "expected a symbol reference with a single nested reference");
1113 if (clusterSize.has_value())
1114 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1115 if (dynamicSharedMemorySize)
1120 prop.kernel = kernelSymbol;
1121 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1123 for (
auto &sz : prop.operandSegmentSizes)
1125 prop.operandSegmentSizes[0] = asyncDependencies.size();
1126 if (!clusterSize.has_value()) {
1127 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1128 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1129 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1131 prop.operandSegmentSizes[segmentSizesLen - 3] =
1132 dynamicSharedMemorySize ? 1 : 0;
1133 prop.operandSegmentSizes[segmentSizesLen - 2] =
1134 static_cast<int32_t
>(kernelOperands.size());
1135 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1143 std::optional<KernelDim3> clusterSize) {
1144 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1147 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1148 build(builder, result, kernelSymbol, gridSize,
getBlockSize,
1149 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1150 asyncDependencies, clusterSize);
1157 std::optional<KernelDim3> clusterSize) {
1161 if (clusterSize.has_value())
1162 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1163 if (dynamicSharedMemorySize)
1169 prop.kernel = kernel;
1170 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1172 for (
auto &sz : prop.operandSegmentSizes)
1174 prop.operandSegmentSizes[0] = 0;
1175 if (!clusterSize.has_value()) {
1176 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1177 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1178 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1180 prop.operandSegmentSizes[segmentSizesLen - 3] =
1181 dynamicSharedMemorySize ? 1 : 0;
1182 prop.operandSegmentSizes[segmentSizesLen - 2] =
1183 static_cast<int32_t
>(kernelOperands.size());
1184 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1187 StringAttr LaunchFuncOp::getKernelModuleName() {
1191 StringAttr LaunchFuncOp::getKernelName() {
1195 unsigned LaunchFuncOp::getNumKernelOperands() {
1196 return getKernelOperands().size();
1199 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1200 return getKernelOperands()[i];
1203 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1204 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1205 return KernelDim3{operands[0], operands[1], operands[2]};
1208 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1209 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1210 return KernelDim3{operands[3], operands[4], operands[5]};
1213 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1214 assert(hasClusterSize() &&
1215 "cluster size is not set, check hasClusterSize() first");
1216 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1217 return KernelDim3{operands[6], operands[7], operands[8]};
1221 auto module = (*this)->getParentOfType<ModuleOp>();
1223 return emitOpError(
"expected to belong to a module");
1225 if (!module->getAttrOfType<UnitAttr>(
1226 GPUDialect::getContainerModuleAttrName()))
1227 return emitOpError(
"expected the closest surrounding module to have the '" +
1228 GPUDialect::getContainerModuleAttrName() +
1231 if (hasClusterSize()) {
1232 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1234 return emitOpError()
1235 <<
"expects types of the cluster dimensions must be the same";
1243 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1244 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1251 if (clusterValue.has_value()) {
1252 clusterXTy = clusterYTy = clusterZTy = dimTy;
1259 Type clusterYTy,
Type clusterZTy) {
1261 printer <<
": " << dimTy;
1271 auto parseElement = [&]() -> ParseResult {
1272 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1277 parseElement,
" in argument list");
1282 if (operands.empty())
1285 llvm::interleaveComma(llvm::zip(operands, types), printer,
1286 [&](
const auto &pair) {
1299 int32_t offset, int32_t width, ShuffleMode mode) {
1300 build(builder, result, value,
1315 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1317 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1328 results.
add(eraseRedundantGpuBarrierOps);
1338 auto attrName = getNumWorkgroupAttributionsAttrName();
1339 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1340 (*this)->setAttr(attrName,
1342 return getBody().insertArgument(
1343 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1351 return getBody().addArgument(type, loc);
1355 StringRef name, FunctionType type,
1365 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
1372 for (
Type argTy : type.getInputs())
1374 for (
Type argTy : workgroupAttributions)
1376 for (
Type argTy : privateAttributions)
1395 size_t existingArgs = args.size();
1396 ParseResult result =
1402 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1407 attributionAttrs =
nullptr;
1413 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1414 if (!argument.attrs)
1417 attributionAttrsVec.push_back(argument.attrs);
1419 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1435 StringAttr nameAttr;
1442 parser,
false, entryArgs, isVariadic, resultTypes,
1446 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1447 return parser.
emitError(signatureLocation)
1448 <<
"gpu.func requires named arguments";
1455 for (
auto &arg : entryArgs)
1456 argTypes.push_back(arg.
type);
1462 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
1463 getResAttrsAttrName(result.
name));
1468 entryArgs, workgroupAttributionAttrs)))
1473 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1474 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1476 if (workgroupAttributionAttrs)
1477 result.
addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.
name),
1478 workgroupAttributionAttrs);
1483 entryArgs, privateAttributionAttrs)))
1485 if (privateAttributionAttrs)
1487 privateAttributionAttrs);
1491 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
1506 ArrayAttr attributes) {
1510 p <<
' ' << keyword <<
'(';
1511 llvm::interleaveComma(
1514 p << v <<
" : " << v.
getType();
1516 size_t attributionIndex = pair.index();
1517 DictionaryAttr attrs;
1518 if (attributes && attributionIndex < attributes.size())
1519 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1530 FunctionType type = getFunctionType();
1536 getWorkgroupAttribAttrs().value_or(
nullptr));
1538 getPrivateAttribAttrs().value_or(
nullptr));
1540 p <<
' ' << getKernelKeyword();
1544 {getNumWorkgroupAttributionsAttrName(),
1545 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1546 getArgAttrsAttrName(), getResAttrsAttrName(),
1547 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1553 StringAttr attrName) {
1554 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->
getAttr(attrName));
1555 if (!allAttrs || index >= allAttrs.size())
1556 return DictionaryAttr();
1557 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1560 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1564 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1569 DictionaryAttr value, StringAttr attrName) {
1571 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->
getAttr(attrName));
1574 elements.append(allAttrs.begin(), allAttrs.end());
1575 while (elements.size() <= index)
1580 elements[index] = value;
1582 op->
setAttr(attrName, newValue);
1585 void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1586 DictionaryAttr value) {
1590 void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1591 DictionaryAttr value) {
1596 StringAttr name, StringAttr attrsName) {
1600 return dict.get(name);
1603 Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1605 assert(index < getNumWorkgroupAttributions() &&
1606 "index must map to a workgroup attribution");
1608 getWorkgroupAttribAttrsAttrName());
1611 Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1613 assert(index < getNumPrivateAttributions() &&
1614 "index must map to a private attribution");
1616 getPrivateAttribAttrsAttrName());
1620 Attribute value, StringAttr attrsName) {
1625 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1628 bool mustSort =
true;
1629 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1630 if (elems[i].getName() == name) {
1633 std::swap(elems[i], elems[elems.size() - 1]);
1645 elems.emplace_back(name, value);
1648 DictionaryAttr::sortInPlace(elems);
1650 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1654 void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1656 assert(index < getNumWorkgroupAttributions() &&
1657 "index must map to a workgroup attribution");
1659 getWorkgroupAttribAttrsAttrName());
1662 void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1664 assert(index < getNumPrivateAttributions() &&
1665 "index must map to a private attribution");
1667 getPrivateAttribAttrsAttrName());
1670 LogicalResult GPUFuncOp::verifyType() {
1671 if (isKernel() && getFunctionType().getNumResults() != 0)
1672 return emitOpError() <<
"expected void return type for kernel function";
1678 LogicalResult GPUFuncOp::verifyBody() {
1680 return emitOpError() <<
"expected body with at least one block";
1681 unsigned numFuncArguments = getNumArguments();
1682 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1683 unsigned numBlockArguments = front().getNumArguments();
1684 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1685 return emitOpError() <<
"expected at least "
1686 << numFuncArguments + numWorkgroupAttributions
1687 <<
" arguments to body region";
1690 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1691 Type blockArgType = front().getArgument(i).getType();
1692 if (funcArgTypes[i] != blockArgType)
1693 return emitOpError() <<
"expected body region argument #" << i
1694 <<
" to be of type " << funcArgTypes[i] <<
", got "
1699 GPUDialect::getWorkgroupAddressSpace())) ||
1701 GPUDialect::getPrivateAddressSpace())))
1712 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1714 FunctionType funType =
function.getFunctionType();
1716 if (funType.getNumResults() != getOperands().size())
1717 return emitOpError()
1718 .append(
"expected ", funType.getNumResults(),
" result operands")
1719 .attachNote(
function.getLoc())
1720 .append(
"return type declared here");
1723 llvm::zip(
function.getFunctionType().getResults(), getOperands()))) {
1724 auto [type, operand] = pair.value();
1725 if (type != operand.getType())
1726 return emitOpError() <<
"unexpected type `" << operand.getType()
1727 <<
"' for operand #" << pair.index();
1737 StringRef name, ArrayAttr targets,
1743 props.targets = targets;
1745 props.offloadingHandler = offloadingHandler;
1751 build(builder, result, name,
1752 targets.empty() ? ArrayAttr() : builder.
getArrayAttr(targets),
1757 StringAttr nameAttr;
1758 ArrayAttr targetsAttr;
1764 props.setSymName(nameAttr);
1778 if (failed(*targetsAttrResult)) {
1781 props.targets = targetsAttr;
1802 if (
Attribute attr = getOffloadingHandlerAttr()) {
1808 if (
Attribute attr = getTargetsAttr()) {
1815 {mlir::SymbolTable::getSymbolAttrName(),
1816 getTargetsAttrName(),
1817 getOffloadingHandlerAttrName()});
1823 bool GPUModuleOp::hasTarget(
Attribute target) {
1824 if (ArrayAttr targets = getTargetsAttr())
1825 return llvm::count(targets.getValue(), target);
1830 ArrayAttr &targetsAttr = getProperties().targets;
1839 Attribute offloadingHandler, ArrayAttr objects) {
1843 properties.objects = objects;
1844 if (offloadingHandler)
1845 properties.offloadingHandler = offloadingHandler;
1847 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1852 build(builder, result, name, offloadingHandler,
1853 objects.empty() ? ArrayAttr() : builder.
getArrayAttr(objects));
1864 if (!offloadingHandler)
1872 printer << '<' << offloadingHandler << '>
';
1875 //===----------------------------------------------------------------------===//
1877 //===----------------------------------------------------------------------===//
1879 LogicalResult MemcpyOp::verify() {
1880 auto srcType = getSrc().getType();
1881 auto dstType = getDst().getType();
1883 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1884 return emitOpError("arguments have incompatible element type");
1886 if (failed(verifyCompatibleShape(srcType, dstType)))
1887 return emitOpError("arguments have incompatible shape");
1896 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1897 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1899 LogicalResult matchAndRewrite(MemcpyOp op,
1900 PatternRewriter &rewriter) const override {
1901 Value dest = op.getDst();
1902 Operation *destDefOp = dest.getDefiningOp();
1903 // `dest` must be defined by an op having Allocate memory effect in order to
1904 // perform the folding.
1906 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1908 // We can erase `op` iff `dest` has no other use apart from its
1909 // use by `op` and dealloc ops.
1910 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1911 return user != op &&
1912 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1915 // We can perform the folding if and only if op has a single async
1916 // dependency and produces an async token as result, or if it does not have
1917 // any async dependency and does not produce any async token result.
1918 if (op.getAsyncDependencies().size() > 1 ||
1919 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1920 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1922 rewriter.replaceOp(op, op.getAsyncDependencies());
1927 } // end anonymous namespace
1929 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1930 MLIRContext *context) {
1931 results.add<EraseTrivialCopyOp>(context);
1934 //===----------------------------------------------------------------------===//
1935 // GPU_SubgroupMmaLoadMatrixOp
1936 //===----------------------------------------------------------------------===//
1938 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1939 auto srcType = getSrcMemref().getType();
1940 auto resType = getRes().getType();
1941 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1942 auto operand = resMatrixType.getOperand();
1943 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1945 if (!isLastMemrefDimUnitStride(srcMemrefType))
1947 "expected source memref most minor dim must have unit stride");
1949 if (operand != "AOp" && operand != "BOp" && operand != "COp")
1950 return emitError("only AOp, BOp and COp can be loaded");
1955 //===----------------------------------------------------------------------===//
1956 // GPU_SubgroupMmaStoreMatrixOp
1957 //===----------------------------------------------------------------------===//
1959 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1960 auto srcType = getSrc().getType();
1961 auto dstType = getDstMemref().getType();
1962 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1963 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1965 if (!isLastMemrefDimUnitStride(dstMemrefType))
1967 "expected destination memref most minor dim must have unit stride");
1969 if (srcMatrixType.getOperand() != "COp")
1971 "expected the operand matrix being stored to have 'COp
' operand type");
1976 //===----------------------------------------------------------------------===//
1977 // GPU_SubgroupMmaComputeOp
1978 //===----------------------------------------------------------------------===//
1980 LogicalResult SubgroupMmaComputeOp::verify() {
1981 enum OperandMap { A, B, C };
1982 SmallVector<MMAMatrixType, 3> opTypes;
1983 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1984 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1985 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1987 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1988 opTypes[C].getOperand() != "COp")
1989 return emitError("operands must be in the order AOp, BOp, COp");
1991 ArrayRef<int64_t> aShape, bShape, cShape;
1992 aShape = opTypes[A].getShape();
1993 bShape = opTypes[B].getShape();
1994 cShape = opTypes[C].getShape();
1996 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1997 bShape[1] != cShape[1])
1998 return emitError("operand shapes do not satisfy matmul constraints");
2003 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2004 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2005 return memref::foldMemRefCast(*this);
2008 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2009 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2010 return memref::foldMemRefCast(*this);
2013 //===----------------------------------------------------------------------===//
2015 //===----------------------------------------------------------------------===//
2022 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2024 using OpRewritePattern::OpRewritePattern;
2026 LogicalResult matchAndRewrite(WaitOp op,
2027 PatternRewriter &rewriter) const final {
2028 auto predicate = [](Value value) {
2029 auto waitOp = value.getDefiningOp<WaitOp>();
2030 return waitOp && waitOp->getNumOperands() == 0;
2032 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2034 SmallVector<Value> validOperands;
2035 for (Value operand : op->getOperands()) {
2036 if (predicate(operand))
2038 validOperands.push_back(operand);
2040 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2052 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2054 using OpRewritePattern::OpRewritePattern;
2056 LogicalResult matchAndRewrite(WaitOp op,
2057 PatternRewriter &rewriter) const final {
2058 // Erase gpu.wait ops that neither have any async dependencies nor return
2060 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2061 rewriter.eraseOp(op);
2064 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2065 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2066 op.getAsyncToken()) {
2067 rewriter.replaceOp(op, op.getAsyncDependencies());
2070 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2071 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2072 rewriter.eraseOp(op);
2079 } // end anonymous namespace
2081 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2082 MLIRContext *context) {
2083 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2086 //===----------------------------------------------------------------------===//
2088 //===----------------------------------------------------------------------===//
2090 LogicalResult AllocOp::verify() {
2091 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2093 if (static_cast<int64_t>(getDynamicSizes().size()) !=
2094 memRefType.getNumDynamicDims())
2095 return emitOpError("dimension operand count does not equal memref "
2096 "dynamic dimension count");
2098 unsigned numSymbols = 0;
2099 if (!memRefType.getLayout().isIdentity())
2100 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2101 if (getSymbolOperands().size() != numSymbols) {
2103 "symbol operand count does not equal memref symbol count");
2113 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2114 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2116 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2117 PatternRewriter &rewriter) const override {
2118 std::optional<int64_t> index = dimOp.getConstantIndex();
2122 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2123 if (!memrefType || !memrefType.isDynamicDim(index.value()))
2126 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2130 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2131 memrefType.getDynamicDimIndex(index.value()));
2132 rewriter.replaceOp(dimOp, substituteOp);
2139 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2140 MLIRContext *context) {
2141 results.add<SimplifyDimOfAllocOp>(context);
2144 //===----------------------------------------------------------------------===//
2145 // GPU object attribute
2146 //===----------------------------------------------------------------------===//
2148 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2149 Attribute target, CompilationTarget format,
2150 StringAttr object, DictionaryAttr properties) {
2152 return emitError() << "the target attribute cannot be null";
2153 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2155 return emitError() << "the target attribute must implement or promise the "
2156 "`gpu::TargetAttrInterface`";
2160 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2161 StringAttr &object) {
2162 std::optional<CompilationTarget> formatResult;
2163 StringRef enumKeyword;
2164 auto loc = odsParser.getCurrentLocation();
2165 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2166 formatResult = CompilationTarget::Fatbin;
2167 if (!formatResult &&
2169 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2170 odsParser.parseEqual())
2171 return odsParser.emitError(loc, "expected an equal sign");
2173 return odsParser.emitError(loc, "expected keyword for GPU object format");
2174 FailureOr<StringAttr> objectResult =
2175 FieldParser<StringAttr>::parse(odsParser);
2176 if (failed(objectResult))
2177 return odsParser.emitError(odsParser.getCurrentLocation(),
2178 "failed to parse GPU_ObjectAttr parameter "
2179 "'
object' which is to be a `StringAttr`");
2180 format = *formatResult;
2181 object = *objectResult;
2185 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2186 StringAttr object) {
2187 if (format != CompilationTarget::Fatbin)
2188 odsParser << stringifyEnum(format) << " = ";
2189 odsParser << object;
2193 //===----------------------------------------------------------------------===//
2194 // GPU select object attribute
2195 //===----------------------------------------------------------------------===//
2198 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2200 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2202 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2203 if (intAttr.getInt() < 0) {
2204 return emitError() << "the object index must be positive";
2206 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2208 << "the target attribute must be a GPU Target attribute";
2214 //===----------------------------------------------------------------------===//
2215 // DynamicSharedMemoryOp
2216 //===----------------------------------------------------------------------===//
2218 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2219 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2220 return emitOpError() << "must be inside an op with symbol table";
2222 MemRefType memrefType = getResultMemref().getType();
2223 // Check address space
2224 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2225 return emitOpError() << "address space must be "
2226 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2227 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2229 if (memrefType.hasStaticShape()) {
2230 return emitOpError() << "result memref type must be memref<?xi8, "
2231 "#gpu.address_space<workgroup>>";
2236 //===----------------------------------------------------------------------===//
2237 // GPU target options
2238 //===----------------------------------------------------------------------===//
2240 TargetOptions::TargetOptions(
2241 StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2242 StringRef cmdOptions, CompilationTarget compilationTarget,
2243 function_ref<SymbolTable *()> getSymbolTableCallback)
2244 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2245 cmdOptions, compilationTarget, getSymbolTableCallback) {}
2247 TargetOptions::TargetOptions(
2248 TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2249 StringRef cmdOptions, CompilationTarget compilationTarget,
2250 function_ref<SymbolTable *()> getSymbolTableCallback)
2251 : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2252 cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2253 getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2255 TypeID TargetOptions::getTypeID() const { return typeID; }
2257 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2259 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2261 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2263 SymbolTable *TargetOptions::getSymbolTable() const {
2264 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2267 CompilationTarget TargetOptions::getCompilationTarget() const {
2268 return compilationTarget;
2271 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2272 return CompilationTarget::Fatbin;
2275 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2276 TargetOptions::tokenizeCmdOptions() const {
2277 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2278 llvm::StringSaver stringSaver(options.first);
2279 StringRef opts = cmdOptions;
2280 // For a correct tokenization of the command line options `opts` must be
2281 // unquoted, otherwise the tokenization function returns a single string: the
2282 // unquoted `cmdOptions` -which is not the desired behavior.
2283 // Remove any quotes if they are at the beginning and end of the string:
2284 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2285 opts.consume_front("\""), opts.consume_back("\"");
2286 if (!opts.empty() && opts.front() == '\
'' && opts.back() ==
'\'')
2287 opts.consume_front(
"'"), opts.consume_back(
"'");
2289 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2292 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2300 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2301 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2303 #define GET_ATTRDEF_CLASSES
2304 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2306 #define GET_OP_CLASSES
2307 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2309 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
static LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the function's attribute list prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR type into the given MLIR context, returning it if valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....