35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/ErrorHandling.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "llvm/Support/InterleavedRange.h"
41 #include "llvm/Support/StringSaver.h"
48 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
54 int64_t GPUBlockMappingAttr::getMappingId()
const {
55 return static_cast<int64_t
>(getBlock());
58 bool GPUBlockMappingAttr::isLinearMapping()
const {
59 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
62 int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
63 return isLinearMapping()
64 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
68 int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
69 return static_cast<int64_t
>(getWarpgroup());
72 bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
73 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
76 int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
77 return isLinearMapping()
78 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
82 int64_t GPUWarpMappingAttr::getMappingId()
const {
83 return static_cast<int64_t
>(getWarp());
86 bool GPUWarpMappingAttr::isLinearMapping()
const {
87 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
90 int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
91 return isLinearMapping()
92 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
96 int64_t GPUThreadMappingAttr::getMappingId()
const {
97 return static_cast<int64_t
>(getThread());
100 bool GPUThreadMappingAttr::isLinearMapping()
const {
101 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
104 int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
105 return isLinearMapping()
106 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
110 int64_t GPULaneMappingAttr::getMappingId()
const {
111 return static_cast<int64_t
>(getLane());
114 bool GPULaneMappingAttr::isLinearMapping()
const {
115 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
118 int64_t GPULaneMappingAttr::getRelativeIndex()
const {
119 return isLinearMapping()
120 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
124 int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds()
const {
return 64; }
136 Value GPUMappingMaskAttr::createLogicalLinearMappingId(
142 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
143 filter = arith::SubIOp::create(b, loc, filter, one);
144 Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
145 return math::CtPopOp::create(b, loc, filteredId);
158 Value GPUMappingMaskAttr::createIsActiveIdPredicate(
164 Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
165 Value filtered = arith::AndIOp::create(b, loc, mask, filter);
167 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
171 int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
172 return static_cast<int64_t
>(getAddressSpace());
175 bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
176 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
179 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
180 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
197 elementType, operand);
211 return elementType.
isF16() || elementType.
isF32() ||
220 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
221 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
223 if (shape.size() != 2)
224 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
228 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
237 bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
240 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
241 return gpuAttr.getValue() == getWorkgroupAddressSpace();
245 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
246 Attribute memorySpace = type.getMemorySpace();
247 return isWorkgroupMemoryAddressSpace(memorySpace);
250 bool GPUDialect::isKernel(
Operation *op) {
251 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
252 return static_cast<bool>(isKernelAttr);
268 void GPUDialect::initialize() {
269 addTypes<AsyncTokenType>();
270 addTypes<MMAMatrixType>();
271 addTypes<SparseDnTensorHandleType>();
272 addTypes<SparseSpMatHandleType>();
273 addTypes<SparseSpGEMMOpHandleType>();
276 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
279 #define GET_ATTRDEF_LIST
280 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
282 addInterfaces<GPUInlinerInterface>();
283 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
285 declarePromisedInterfaces<
286 ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
287 ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
288 SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
294 return "sparse.dntensor_handle";
296 return "sparse.spmat_handle";
298 return "sparse.spgemmop_handle";
300 llvm_unreachable(
"unknown sparse handle kind");
312 if (keyword ==
"async.token")
315 if (keyword ==
"mma_matrix") {
344 shape, elementType, operand);
362 .Case<SparseDnTensorHandleType>([&](
Type) {
365 .Case<SparseSpMatHandleType>(
367 .Case<SparseSpGEMMOpHandleType>([&](
Type) {
373 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
376 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
378 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
383 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
386 " must be a dense i32 array");
387 if (array.size() != 3)
389 " must contain exactly 3 elements");
393 LogicalResult GPUDialect::verifyOperationAttribute(
Operation *op,
395 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
397 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
399 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
400 attr.
getName() != getContainerModuleAttrName())
403 auto module = dyn_cast<ModuleOp>(op);
406 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
407 << ModuleOp::getOperationName() <<
'\'';
409 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
412 if (!launchOp->getParentOp() ||
413 launchOp->getParentOp()->getParentOp() != module)
418 if (!launchOp->getAttrOfType<SymbolRefAttr>(
419 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
423 StringAttr kernelContainerName = launchOp.getKernelModuleName();
424 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
425 if (!kernelContainer)
427 <<
"kernel container '" << kernelContainerName.getValue()
431 if (isa<BinaryOp>(kernelContainer))
434 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
436 return launchOp.emitOpError()
437 <<
"kernel module '" << kernelContainerName.getValue()
441 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
444 << launchOp.getKernel() <<
"' is undefined";
445 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
446 if (!kernelConvertedFunction) {
448 <<
"referenced kernel '" << launchOp.getKernel()
449 <<
"' is not a function";
450 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
455 GPUDialect::getKernelFuncAttrName()))
456 return launchOp.emitOpError(
"kernel function is missing the '")
457 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
462 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
463 if (!kernelGPUFunction)
466 unsigned actualNumArguments = launchOp.getNumKernelOperands();
467 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
468 if (expectedNumArguments != actualNumArguments)
469 return launchOp.emitOpError(
"got ")
470 << actualNumArguments <<
" kernel operands but expected "
471 << expectedNumArguments;
473 auto functionType = kernelGPUFunction.getFunctionType();
474 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
475 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
476 return launchOp.emitOpError(
"type of function argument ")
477 << i <<
" does not match";
484 return walkResult.wasInterrupted() ? failure() : success();
497 return parser.
emitError(loc,
"needs to be named when marked 'async'");
512 if (asyncDependencies.empty())
516 printer << llvm::interleaved_array(asyncDependencies);
545 return llvm::formatv(
"{} : {}", v, v.getType());
547 p <<
' ' << keyword <<
'('
548 << llvm::interleaved(llvm::map_range(values, printBlockArg)) <<
')';
554 gpu::AddressSpace memorySpace) {
555 for (
Value v : attributions) {
556 auto type = llvm::dyn_cast<MemRefType>(v.getType());
558 return op->
emitOpError() <<
"expected memref type in attribution";
563 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
566 if (addressSpace.getValue() != memorySpace)
568 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
569 <<
" in attribution";
580 using Kind = gpu::AllReduceOperation;
581 if (llvm::is_contained(
582 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
584 if (!isa<FloatType>(resType))
588 if (llvm::is_contained({Kind::MINSI,
Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
589 Kind::AND, Kind::OR, Kind::XOR},
591 if (!isa<IntegerType>(resType))
598 LogicalResult gpu::AllReduceOp::verifyRegions() {
599 if (getBody().empty() != getOp().has_value())
600 return emitError(
"expected either an op attribute or a non-empty body");
601 if (!getBody().empty()) {
602 if (getBody().getNumArguments() != 2)
603 return emitError(
"expected two region arguments");
604 for (
auto argument : getBody().getArguments()) {
605 if (argument.getType() !=
getType())
606 return emitError(
"incorrect region argument type");
608 unsigned yieldCount = 0;
609 for (
Block &block : getBody()) {
610 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
611 if (yield.getNumOperands() != 1)
612 return emitError(
"expected one gpu.yield operand");
613 if (yield.getOperand(0).getType() !=
getType())
614 return emitError(
"incorrect gpu.yield type");
619 return emitError(
"expected gpu.yield op in region");
621 gpu::AllReduceOperation opName = *getOp();
623 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
624 <<
"` reduction operation is not compatible with type "
633 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
637 Region &body = launchOp.getBody();
638 assert(!body.
empty() &&
"Invalid region");
655 AllReduceOperationAttr &attr) {
658 std::optional<AllReduceOperation> op =
659 gpu::symbolizeAllReduceOperation(enumStr);
668 AllReduceOperationAttr attr) {
679 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
680 if (vecTy.isScalable())
681 return emitOpError() <<
"is not compatible with scalable vector types";
683 elemType = vecTy.getElementType();
686 gpu::AllReduceOperation opName = getOp();
688 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
689 <<
"` reduction operation is not compatible with type "
693 auto clusterSize = getClusterSize();
695 uint32_t size = *clusterSize;
696 if (!llvm::isPowerOf2_32(size)) {
697 return emitOpError() <<
"cluster size " << size
698 <<
" is not a power of two";
702 uint32_t stride = getClusterStride();
703 if (stride != 1 && !clusterSize) {
704 return emitOpError() <<
"cluster stride can only be specified if cluster "
707 if (!llvm::isPowerOf2_32(stride)) {
708 return emitOpError() <<
"cluster stride " << stride
709 <<
" is not a power of two";
715 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
716 if (getClusterSize() == 1)
733 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
737 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
755 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
765 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
774 result.
addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
775 getBlockSizeY, getBlockSizeZ});
782 if (dynamicSharedMemorySize)
797 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
800 for (
Type argTy : workgroupAttributions)
802 for (
Type argTy : privateAttributions)
806 segmentSizes.front() = asyncDependencies.size();
807 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
808 segmentSizes[7] = clusterSizeX ? 1 : 0;
809 segmentSizes[8] = clusterSizeY ? 1 : 0;
810 segmentSizes[9] = clusterSizeZ ? 1 : 0;
816 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
817 auto args = getBody().getArguments();
822 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
823 auto args = getBody().getArguments();
828 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
829 auto args = getBody().getArguments();
834 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
835 auto args = getBody().getArguments();
836 return KernelDim3{args[9], args[10], args[11]};
839 std::optional<KernelDim3> LaunchOp::getClusterIds() {
840 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
841 if (!hasClusterSize())
843 auto args = getBody().getArguments();
844 return KernelDim3{args[12], args[13], args[14]};
847 std::optional<KernelDim3> LaunchOp::getClusterSize() {
848 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
849 if (!hasClusterSize())
851 auto args = getBody().getArguments();
852 return KernelDim3{args[15], args[16], args[17]};
855 KernelDim3 LaunchOp::getGridSizeOperandValues() {
856 auto operands = getOperands().drop_front(getAsyncDependencies().size());
857 return KernelDim3{operands[0], operands[1], operands[2]};
860 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
861 auto operands = getOperands().drop_front(getAsyncDependencies().size());
862 return KernelDim3{operands[3], operands[4], operands[5]};
865 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
866 auto operands = getOperands().drop_front(getAsyncDependencies().size());
867 if (!hasClusterSize())
869 return KernelDim3{operands[6], operands[7], operands[8]};
873 if (!(hasClusterSize()) &&
874 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
875 return emitOpError() <<
"cluster size must be all present";
879 LogicalResult LaunchOp::verifyRegions() {
883 if (!getBody().empty()) {
884 if (getBody().getNumArguments() <
885 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
886 return emitOpError(
"unexpected number of region arguments");
891 GPUDialect::getWorkgroupAddressSpace())) ||
893 GPUDialect::getPrivateAddressSpace())))
898 for (
Block &block : getBody()) {
901 if (block.back().getNumSuccessors() != 0)
903 if (!isa<gpu::TerminatorOp>(&block.back())) {
906 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
907 "' or a terminator with successors")
908 .attachNote(getLoc())
909 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
913 if (getNumResults() == 0 && getAsyncToken())
914 return emitOpError(
"needs to be named when async keyword is specified");
925 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
926 p << size.
x <<
" = " << operands.
x <<
", ";
927 p << size.
y <<
" = " << operands.
y <<
", ";
928 p << size.
z <<
" = " << operands.
z <<
')';
932 if (getAsyncToken()) {
934 if (!getAsyncDependencies().empty())
935 p <<
" [" << getAsyncDependencies() <<
']';
938 if (hasClusterSize()) {
939 p <<
' ' << getClustersKeyword();
941 getClusterSizeOperandValues().value(),
942 getClusterIds().value());
944 p <<
' ' << getBlocksKeyword();
947 p <<
' ' << getThreadsKeyword();
950 if (getDynamicSharedMemorySize())
951 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
952 << getDynamicSharedMemorySize();
955 StringRef moduleAttrName = getModuleAttrName();
956 if (
auto module = getModule()) {
957 p <<
' ' << moduleAttrName <<
'(';
962 StringRef functionAttrName = getFunctionAttrName();
963 if (
auto function = getFunction()) {
964 p <<
' ' << functionAttrName <<
'(';
976 LaunchOp::getOperandSegmentSizeAttr(),
977 getNumWorkgroupAttributionsAttrName(),
978 moduleAttrName, functionAttrName});
992 assert(indices.size() == 3 &&
"space for three indices expected");
998 std::move(args.begin(), args.end(), indices.begin());
1000 for (
int i = 0; i < 3; ++i) {
1025 sizes(LaunchOp::kNumConfigOperands);
1029 LaunchOp::kNumConfigRegionAttributes);
1033 Type asyncTokenType;
1040 result.
types.push_back(asyncTokenType);
1042 bool hasCluster =
false;
1047 regionArgs.resize(18);
1056 regionArgsRef.slice(15, 3),
1057 regionArgsRef.slice(12, 3)))
1065 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
1067 regionArgsRef.slice(6, 3),
1068 regionArgsRef.slice(0, 3)) ||
1069 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
1071 regionArgsRef.slice(9, 3),
1072 regionArgsRef.slice(3, 3)) ||
1078 bool hasDynamicSharedMemorySize =
false;
1080 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1081 hasDynamicSharedMemorySize =
true;
1090 StringRef moduleAttrName = getModuleAttrName(result.
name);
1100 StringRef functionAttrName = getFunctionAttrName(result.
name);
1118 LaunchOp::kNumConfigRegionAttributes + 6, index);
1121 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1123 arg.
ssaName = std::get<0>(ssaValueAndType);
1124 arg.
type = std::get<1>(ssaValueAndType);
1125 regionArguments.push_back(arg);
1136 unsigned numWorkgroupAttrs = regionArguments.size() -
1137 LaunchOp::kNumConfigRegionAttributes -
1138 (hasCluster ? 6 : 0);
1139 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1156 segmentSizes.front() = asyncDependencies.size();
1159 segmentSizes[7] = 0;
1160 segmentSizes[8] = 0;
1161 segmentSizes[9] = 0;
1163 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1164 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1178 bool simplified =
false;
1179 auto constPropIdUses = [&](
Value id,
Value size) {
1183 if (
id.getUses().empty())
1195 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1196 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1197 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1198 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1199 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1200 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1202 return success(simplified);
1214 auto attrName = getNumWorkgroupAttributionsAttrName();
1215 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1216 (*this)->setAttr(attrName,
1218 return getBody().insertArgument(
1219 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1227 return getBody().addArgument(type, loc);
1235 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1239 std::optional<KernelDim3> clusterSize) {
1240 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1241 "expected a symbol reference with a single nested reference");
1249 if (clusterSize.has_value())
1250 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1251 if (dynamicSharedMemorySize)
1256 prop.kernel = kernelSymbol;
1257 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1259 llvm::fill(prop.operandSegmentSizes, 1);
1260 prop.operandSegmentSizes[0] = asyncDependencies.size();
1261 if (!clusterSize.has_value()) {
1262 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1263 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1264 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1266 prop.operandSegmentSizes[segmentSizesLen - 3] =
1267 dynamicSharedMemorySize ? 1 : 0;
1268 prop.operandSegmentSizes[segmentSizesLen - 2] =
1269 static_cast<int32_t
>(kernelOperands.size());
1270 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1278 std::optional<KernelDim3> clusterSize) {
1279 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1282 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1283 build(builder, result, kernelSymbol, gridSize,
getBlockSize,
1284 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1285 asyncDependencies, clusterSize);
1292 std::optional<KernelDim3> clusterSize) {
1296 if (clusterSize.has_value())
1297 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1298 if (dynamicSharedMemorySize)
1304 prop.kernel = kernel;
1305 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1307 llvm::fill(prop.operandSegmentSizes, 1);
1308 prop.operandSegmentSizes[0] = 0;
1309 if (!clusterSize.has_value()) {
1310 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1311 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1312 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1314 prop.operandSegmentSizes[segmentSizesLen - 3] =
1315 dynamicSharedMemorySize ? 1 : 0;
1316 prop.operandSegmentSizes[segmentSizesLen - 2] =
1317 static_cast<int32_t
>(kernelOperands.size());
1318 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1321 StringAttr LaunchFuncOp::getKernelModuleName() {
1325 StringAttr LaunchFuncOp::getKernelName() {
1329 unsigned LaunchFuncOp::getNumKernelOperands() {
1330 return getKernelOperands().size();
1333 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1334 return getKernelOperands()[i];
1337 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1338 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1339 return KernelDim3{operands[0], operands[1], operands[2]};
1342 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1343 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1344 return KernelDim3{operands[3], operands[4], operands[5]};
1347 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1348 assert(hasClusterSize() &&
1349 "cluster size is not set, check hasClusterSize() first");
1350 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1351 return KernelDim3{operands[6], operands[7], operands[8]};
1355 auto module = (*this)->getParentOfType<ModuleOp>();
1357 return emitOpError(
"expected to belong to a module");
1359 if (!module->getAttrOfType<UnitAttr>(
1360 GPUDialect::getContainerModuleAttrName()))
1361 return emitOpError(
"expected the closest surrounding module to have the '" +
1362 GPUDialect::getContainerModuleAttrName() +
1365 if (hasClusterSize()) {
1366 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1368 return emitOpError()
1369 <<
"expects types of the cluster dimensions must be the same";
1377 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1378 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1385 if (clusterValue.has_value()) {
1386 clusterXTy = clusterYTy = clusterZTy = dimTy;
1393 Type clusterYTy,
Type clusterZTy) {
1395 printer <<
": " << dimTy;
1405 auto parseElement = [&]() -> ParseResult {
1406 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1411 parseElement,
" in argument list");
1416 if (operands.empty())
1419 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1420 [&](
const auto &pair) {
1421 auto [operand, type] = pair;
1422 printer << operand <<
" : " << type;
1432 int32_t offset, int32_t width, ShuffleMode mode) {
1433 build(builder, result, value,
1434 arith::ConstantOp::create(builder, result.
location,
1436 arith::ConstantOp::create(builder, result.
location,
1446 uint32_t offset = getOffset();
1447 uint32_t width = getWidth();
1449 if (offset >= width) {
1450 return emitOpError() <<
"offset must be in the range [0, " << width <<
")";
1463 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1465 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1476 results.
add(eraseRedundantGpuBarrierOps);
1486 auto attrName = getNumWorkgroupAttributionsAttrName();
1487 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1488 (*this)->setAttr(attrName,
1490 return getBody().insertArgument(
1491 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1499 return getBody().addArgument(type, loc);
1503 StringRef name, FunctionType type,
1513 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
1520 for (
Type argTy : type.getInputs())
1522 for (
Type argTy : workgroupAttributions)
1524 for (
Type argTy : privateAttributions)
1543 size_t existingArgs = args.size();
1544 ParseResult result =
1550 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1555 attributionAttrs =
nullptr;
1561 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1562 if (!argument.attrs)
1565 attributionAttrsVec.push_back(argument.attrs);
1567 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1583 StringAttr nameAttr;
1590 parser,
false, entryArgs, isVariadic, resultTypes,
1594 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1595 return parser.
emitError(signatureLocation)
1596 <<
"gpu.func requires named arguments";
1603 for (
auto &arg : entryArgs)
1604 argTypes.push_back(arg.
type);
1610 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
1611 getResAttrsAttrName(result.
name));
1616 entryArgs, workgroupAttributionAttrs)))
1621 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1622 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1624 if (workgroupAttributionAttrs)
1625 result.
addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.
name),
1626 workgroupAttributionAttrs);
1631 entryArgs, privateAttributionAttrs)))
1633 if (privateAttributionAttrs)
1635 privateAttributionAttrs);
1639 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
1654 ArrayAttr attributes) {
1658 p <<
' ' << keyword <<
'(';
1659 llvm::interleaveComma(
1662 p << v <<
" : " << v.
getType();
1664 size_t attributionIndex = pair.index();
1665 DictionaryAttr attrs;
1666 if (attributes && attributionIndex < attributes.size())
1667 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1678 FunctionType type = getFunctionType();
1684 getWorkgroupAttribAttrs().value_or(
nullptr));
1686 getPrivateAttribAttrs().value_or(
nullptr));
1688 p <<
' ' << getKernelKeyword();
1692 {getNumWorkgroupAttributionsAttrName(),
1693 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1694 getArgAttrsAttrName(), getResAttrsAttrName(),
1695 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1701 StringAttr attrName) {
1702 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1703 if (!allAttrs || index >= allAttrs.size())
1704 return DictionaryAttr();
1705 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1708 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1712 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1717 DictionaryAttr value, StringAttr attrName) {
1719 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1722 elements.append(allAttrs.begin(), allAttrs.end());
1723 while (elements.size() <= index)
1728 elements[index] = value;
1730 op->setAttr(attrName, newValue);
1733 void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1734 DictionaryAttr value) {
1738 void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1739 DictionaryAttr value) {
1744 StringAttr name, StringAttr attrsName) {
1748 return dict.get(name);
1751 Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1753 assert(index < getNumWorkgroupAttributions() &&
1754 "index must map to a workgroup attribution");
1756 getWorkgroupAttribAttrsAttrName());
1759 Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1761 assert(index < getNumPrivateAttributions() &&
1762 "index must map to a private attribution");
1764 getPrivateAttribAttrsAttrName());
1768 Attribute value, StringAttr attrsName) {
1773 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1776 bool mustSort =
true;
1777 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1778 if (elems[i].getName() == name) {
1781 std::swap(elems[i], elems[elems.size() - 1]);
1793 elems.emplace_back(name, value);
1796 DictionaryAttr::sortInPlace(elems);
1798 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1802 void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1804 assert(index < getNumWorkgroupAttributions() &&
1805 "index must map to a workgroup attribution");
1807 getWorkgroupAttribAttrsAttrName());
1810 void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1812 assert(index < getNumPrivateAttributions() &&
1813 "index must map to a private attribution");
1815 getPrivateAttribAttrsAttrName());
1818 LogicalResult GPUFuncOp::verifyType() {
1819 if (isKernel() && getFunctionType().getNumResults() != 0)
1820 return emitOpError() <<
"expected void return type for kernel function";
1826 LogicalResult GPUFuncOp::verifyBody() {
1828 return emitOpError() <<
"expected body with at least one block";
1829 unsigned numFuncArguments = getNumArguments();
1830 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1831 unsigned numBlockArguments = front().getNumArguments();
1832 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1833 return emitOpError() <<
"expected at least "
1834 << numFuncArguments + numWorkgroupAttributions
1835 <<
" arguments to body region";
1838 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1839 Type blockArgType = front().getArgument(i).getType();
1840 if (funcArgTypes[i] != blockArgType)
1841 return emitOpError() <<
"expected body region argument #" << i
1842 <<
" to be of type " << funcArgTypes[i] <<
", got "
1847 GPUDialect::getWorkgroupAddressSpace())) ||
1849 GPUDialect::getPrivateAddressSpace())))
1860 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1862 FunctionType funType =
function.getFunctionType();
1864 if (funType.getNumResults() != getOperands().size())
1865 return emitOpError()
1866 .append(
"expected ", funType.getNumResults(),
" result operands")
1867 .attachNote(
function.getLoc())
1868 .append(
"return type declared here");
1871 llvm::zip(
function.getFunctionType().getResults(), getOperands()))) {
1872 auto [type, operand] = pair.value();
1873 if (type != operand.getType())
1874 return emitOpError() <<
"unexpected type `" << operand.getType()
1875 <<
"' for operand #" << pair.index();
1885 StringRef name, ArrayAttr targets,
1890 props.targets = targets;
1892 props.offloadingHandler = offloadingHandler;
1898 build(builder, result, name,
1899 targets.empty() ? ArrayAttr() : builder.
getArrayAttr(targets),
1903 bool GPUModuleOp::hasTarget(
Attribute target) {
1904 if (ArrayAttr targets = getTargetsAttr())
1905 return llvm::count(targets.getValue(), target);
1910 ArrayAttr &targetsAttr = getProperties().targets;
1916 auto targets = getOperation()->getAttrOfType<ArrayAttr>(
"targets");
1921 for (
auto target : targets) {
1922 if (
auto verifyTargetAttr =
1923 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1924 if (verifyTargetAttr.verifyTarget(getOperation()).
failed())
1935 Attribute offloadingHandler, ArrayAttr objects) {
1939 properties.objects = objects;
1940 if (offloadingHandler)
1941 properties.offloadingHandler = offloadingHandler;
1943 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1948 build(builder, result, name, offloadingHandler,
1949 objects.empty() ? ArrayAttr() : builder.
getArrayAttr(objects));
1960 if (!offloadingHandler)
1968 printer << '<' << offloadingHandler << '>
';
1971 //===----------------------------------------------------------------------===//
1973 //===----------------------------------------------------------------------===//
1975 LogicalResult MemcpyOp::verify() {
1976 auto srcType = getSrc().getType();
1977 auto dstType = getDst().getType();
1979 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1980 return emitOpError("arguments have incompatible element type");
1982 if (failed(verifyCompatibleShape(srcType, dstType)))
1983 return emitOpError("arguments have incompatible shape");
1992 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1993 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1995 LogicalResult matchAndRewrite(MemcpyOp op,
1996 PatternRewriter &rewriter) const override {
1997 Value dest = op.getDst();
1998 Operation *destDefOp = dest.getDefiningOp();
1999 // `dest` must be defined by an op having Allocate memory effect in order to
2000 // perform the folding.
2002 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
2004 // We can erase `op` iff `dest` has no other use apart from its
2005 // use by `op` and dealloc ops.
2006 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2007 return user != op &&
2008 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2011 // We can perform the folding if and only if op has a single async
2012 // dependency and produces an async token as result, or if it does not have
2013 // any async dependency and does not produce any async token result.
2014 if (op.getAsyncDependencies().size() > 1 ||
2015 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2016 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2018 rewriter.replaceOp(op, op.getAsyncDependencies());
2023 } // end anonymous namespace
2025 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2026 MLIRContext *context) {
2027 results.add<EraseTrivialCopyOp>(context);
2030 //===----------------------------------------------------------------------===//
2031 // GPU_SubgroupMmaLoadMatrixOp
2032 //===----------------------------------------------------------------------===//
2034 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2035 auto srcType = getSrcMemref().getType();
2036 auto resType = getRes().getType();
2037 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2038 auto operand = resMatrixType.getOperand();
2039 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2041 if (!srcMemrefType.isLastDimUnitStride())
2043 "expected source memref most minor dim must have unit stride");
2045 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2046 return emitError("only AOp, BOp and COp can be loaded");
2051 //===----------------------------------------------------------------------===//
2052 // GPU_SubgroupMmaStoreMatrixOp
2053 //===----------------------------------------------------------------------===//
2055 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2056 auto srcType = getSrc().getType();
2057 auto dstType = getDstMemref().getType();
2058 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2059 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2061 if (!dstMemrefType.isLastDimUnitStride())
2063 "expected destination memref most minor dim must have unit stride");
2065 if (srcMatrixType.getOperand() != "COp")
2067 "expected the operand matrix being stored to have 'COp
' operand type");
2072 //===----------------------------------------------------------------------===//
2073 // GPU_SubgroupMmaComputeOp
2074 //===----------------------------------------------------------------------===//
2076 LogicalResult SubgroupMmaComputeOp::verify() {
2077 enum OperandMap { A, B, C };
2078 SmallVector<MMAMatrixType, 3> opTypes;
2079 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2080 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2081 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2083 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2084 opTypes[C].getOperand() != "COp")
2085 return emitError("operands must be in the order AOp, BOp, COp");
2087 ArrayRef<int64_t> aShape, bShape, cShape;
2088 aShape = opTypes[A].getShape();
2089 bShape = opTypes[B].getShape();
2090 cShape = opTypes[C].getShape();
2092 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2093 bShape[1] != cShape[1])
2094 return emitError("operand shapes do not satisfy matmul constraints");
2099 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2100 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2101 return memref::foldMemRefCast(*this);
2104 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2105 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2106 return memref::foldMemRefCast(*this);
2109 //===----------------------------------------------------------------------===//
2111 //===----------------------------------------------------------------------===//
2118 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2120 using OpRewritePattern::OpRewritePattern;
2122 LogicalResult matchAndRewrite(WaitOp op,
2123 PatternRewriter &rewriter) const final {
2124 auto predicate = [](Value value) {
2125 auto waitOp = value.getDefiningOp<WaitOp>();
2126 return waitOp && waitOp->getNumOperands() == 0;
2128 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2130 SmallVector<Value> validOperands;
2131 for (Value operand : op->getOperands()) {
2132 if (predicate(operand))
2134 validOperands.push_back(operand);
2136 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2148 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2150 using OpRewritePattern::OpRewritePattern;
2152 LogicalResult matchAndRewrite(WaitOp op,
2153 PatternRewriter &rewriter) const final {
2154 // Erase gpu.wait ops that neither have any async dependencies nor return
2156 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2157 rewriter.eraseOp(op);
2160 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2161 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2162 op.getAsyncToken()) {
2163 rewriter.replaceOp(op, op.getAsyncDependencies());
2166 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2167 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2168 rewriter.eraseOp(op);
2175 } // end anonymous namespace
2177 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2178 MLIRContext *context) {
2179 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2182 //===----------------------------------------------------------------------===//
2184 //===----------------------------------------------------------------------===//
2186 LogicalResult AllocOp::verify() {
2187 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2189 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2190 return emitOpError("dimension operand count does not equal memref "
2191 "dynamic dimension count");
2193 unsigned numSymbols = 0;
2194 if (!memRefType.getLayout().isIdentity())
2195 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2196 if (getSymbolOperands().size() != numSymbols) {
2198 "symbol operand count does not equal memref symbol count");
2208 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2209 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2211 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2212 PatternRewriter &rewriter) const override {
2213 std::optional<int64_t> index = dimOp.getConstantIndex();
2217 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2218 if (!memrefType || index.value() >= memrefType.getRank() ||
2219 !memrefType.isDynamicDim(index.value()))
2222 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2226 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2227 memrefType.getDynamicDimIndex(index.value()));
2228 rewriter.replaceOp(dimOp, substituteOp);
2235 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2236 MLIRContext *context) {
2237 results.add<SimplifyDimOfAllocOp>(context);
2240 //===----------------------------------------------------------------------===//
2241 // GPU object attribute
2242 //===----------------------------------------------------------------------===//
2244 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2245 Attribute target, CompilationTarget format,
2246 StringAttr object, DictionaryAttr properties,
2247 KernelTableAttr kernels) {
2249 return emitError() << "the target attribute cannot be null";
2250 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2252 return emitError() << "the target attribute must implement or promise the "
2253 "`gpu::TargetAttrInterface`";
2257 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2258 StringAttr &object) {
2259 std::optional<CompilationTarget> formatResult;
2260 StringRef enumKeyword;
2261 auto loc = odsParser.getCurrentLocation();
2262 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2263 formatResult = CompilationTarget::Fatbin;
2264 if (!formatResult &&
2266 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2267 odsParser.parseEqual())
2268 return odsParser.emitError(loc, "expected an equal sign");
2270 return odsParser.emitError(loc, "expected keyword for GPU object format");
2271 FailureOr<StringAttr> objectResult =
2272 FieldParser<StringAttr>::parse(odsParser);
2273 if (failed(objectResult))
2274 return odsParser.emitError(odsParser.getCurrentLocation(),
2275 "failed to parse GPU_ObjectAttr parameter "
2276 "'
object' which is to be a `StringAttr`");
2277 format = *formatResult;
2278 object = *objectResult;
2282 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2283 StringAttr object) {
2284 if (format != CompilationTarget::Fatbin)
2285 odsParser << stringifyEnum(format) << " = ";
2286 odsParser << object;
2290 //===----------------------------------------------------------------------===//
2291 // GPU select object attribute
2292 //===----------------------------------------------------------------------===//
2295 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2297 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2299 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2300 if (intAttr.getInt() < 0) {
2301 return emitError() << "the object index must be positive";
2303 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2305 << "the target attribute must be a GPU Target attribute";
2311 //===----------------------------------------------------------------------===//
2312 // DynamicSharedMemoryOp
2313 //===----------------------------------------------------------------------===//
2315 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2316 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2317 return emitOpError() << "must be inside an op with symbol table";
2319 MemRefType memrefType = getResultMemref().getType();
2320 // Check address space
2321 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2322 return emitOpError() << "address space must be "
2323 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2324 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2326 if (memrefType.hasStaticShape()) {
2327 return emitOpError() << "result memref type must be memref<?xi8, "
2328 "#gpu.address_space<workgroup>>";
2333 //===----------------------------------------------------------------------===//
2334 // GPU WarpExecuteOnLane0Op
2335 //===----------------------------------------------------------------------===//
2337 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2338 p << "(" << getLaneid() << ")";
2340 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2341 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2342 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2344 if (!getArgs().empty())
2345 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2346 if (!getResults().empty())
2347 p << " -> (" << getResults().getTypes() << ')
';
2349 p.printRegion(getRegion(),
2350 /*printEntryBlockArgs=*/true,
2351 /*printBlockTerminators=*/!getResults().empty());
2352 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2355 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2356 OperationState &result) {
2357 // Create the region.
2358 result.regions.reserve(1);
2359 Region *warpRegion = result.addRegion();
2361 auto &builder = parser.getBuilder();
2362 OpAsmParser::UnresolvedOperand laneId;
2364 // Parse predicate operand.
2365 if (parser.parseLParen() ||
2366 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2367 parser.parseRParen())
2371 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2372 parser.parseRSquare())
2374 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2375 builder.getContext())),
2376 builder.getI64IntegerAttr(warpSize));
2378 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2381 llvm::SMLoc inputsOperandsLoc;
2382 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2383 SmallVector<Type> inputTypes;
2384 if (succeeded(parser.parseOptionalKeyword("args"))) {
2385 if (parser.parseLParen())
2388 inputsOperandsLoc = parser.getCurrentLocation();
2389 if (parser.parseOperandList(inputsOperands) ||
2390 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2393 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2397 // Parse optional results type list.
2398 if (parser.parseOptionalArrowTypeList(result.types))
2400 // Parse the region.
2401 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2404 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2406 // Parse the optional attribute list.
2407 if (parser.parseOptionalAttrDict(result.attributes))
2412 void WarpExecuteOnLane0Op::getSuccessorRegions(
2413 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2414 if (!point.isParent()) {
2415 regions.push_back(RegionSuccessor(getResults()));
2419 // The warp region is always executed
2420 regions.push_back(RegionSuccessor(&getWarpRegion()));
2423 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2424 TypeRange resultTypes, Value laneId,
2426 build(builder, result, resultTypes, laneId, warpSize,
2427 /*operands=*/{}, /*argTypes=*/{});
2430 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2431 TypeRange resultTypes, Value laneId,
2432 int64_t warpSize, ValueRange args,
2433 TypeRange blockArgTypes) {
2434 result.addOperands(laneId);
2435 result.addAttribute(getAttributeNames()[0],
2436 builder.getI64IntegerAttr(warpSize));
2437 result.addTypes(resultTypes);
2438 result.addOperands(args);
2439 assert(args.size() == blockArgTypes.size());
2440 OpBuilder::InsertionGuard guard(builder);
2441 Region *warpRegion = result.addRegion();
2442 Block *block = builder.createBlock(warpRegion);
2443 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2444 block->addArgument(type, arg.getLoc());
2449 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2450 int64_t warpSize, Operation *op) {
2451 // If the types matches there is no distribution.
2452 if (expanded == distributed)
2454 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2455 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2456 if (!expandedVecType || !distributedVecType)
2457 return op->emitOpError("expected vector type for distributed operands.");
2458 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2459 expandedVecType.getElementType() != distributedVecType.getElementType())
2460 return op->emitOpError(
2461 "expected distributed vectors to have same rank and element type.");
2463 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2464 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2465 int64_t eDim = expandedVecType.getDimSize(i);
2466 int64_t dDim = distributedVecType.getDimSize(i);
2469 if (eDim % dDim != 0)
2470 return op->emitOpError()
2471 << "expected expanded vector dimension #" << i << " (" << eDim
2472 << ") to be a multipler of the distributed vector dimension ("
2474 scales[i] = eDim / dDim;
2476 if (std::accumulate(scales.begin(), scales.end(), 1,
2477 std::multiplies<int64_t>()) != warpSize)
2478 return op->emitOpError()
2479 << "incompatible distribution dimensions from " << expandedVecType
2480 << " to " << distributedVecType << " with warp size = " << warpSize;
2485 LogicalResult WarpExecuteOnLane0Op::verify() {
2486 if (getArgs().size() != getWarpRegion().getNumArguments())
2488 "expected same number op arguments and block arguments.");
2489 gpu::YieldOp yield = getTerminator();
2490 if (yield.getNumOperands() != getNumResults())
2492 "expected same number of yield operands and return values.");
2493 int64_t warpSize = getWarpSize();
2494 for (auto [regionArg, arg] :
2495 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2496 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2497 warpSize, getOperation())))
2500 for (auto [yieldOperand, result] :
2501 llvm::zip_equal(yield.getOperands(), getResults())) {
2502 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2503 warpSize, getOperation())))
2508 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2510 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2513 gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2514 return cast<gpu::YieldOp>(getBody()->getTerminator());
2517 //===----------------------------------------------------------------------===//
2518 // GPU_SubgroupBroadcastOp
2519 //===----------------------------------------------------------------------===//
2521 void gpu::SubgroupBroadcastOp::inferResultRanges(
2522 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2523 setResultRange(getResult(), argRanges.front());
2526 Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2527 switch (getBroadcastType()) {
2528 case BroadcastType::first_active_lane:
2529 // Cannot speculate first_lane broadcast, because speculating it across
2530 // control flow can change the active lanes.
2531 return Speculation::NotSpeculatable;
2532 case BroadcastType::specific_lane:
2533 // Speculation should be safe as long as we inside structured control flow.
2534 return Speculation::Speculatable;
2538 LogicalResult gpu::SubgroupBroadcastOp::verify() {
2539 switch (getBroadcastType()) {
2540 case BroadcastType::first_active_lane:
2542 return emitOpError()
2543 << "lane can only be specified for `specific_lane` broadcast";
2545 case BroadcastType::specific_lane:
2547 return emitOpError()
2548 << "lane must be specified for `specific_lane` broadcast";
2553 //===----------------------------------------------------------------------===//
2554 // GPU KernelMetadataAttr
2555 //===----------------------------------------------------------------------===//
2557 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2558 DictionaryAttr metadata) {
2559 assert(kernel && "invalid kernel");
2560 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2561 kernel.getAllArgAttrs(), metadata);
2565 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2566 FunctionOpInterface kernel,
2567 DictionaryAttr metadata) {
2568 assert(kernel && "invalid kernel");
2569 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2570 kernel.getAllArgAttrs(), metadata);
2574 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2577 NamedAttrList attrList;
2578 if (DictionaryAttr dict = getMetadata())
2579 attrList.append(dict);
2580 attrList.append(attrs);
2581 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2582 attrList.getDictionary(getContext()));
2586 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2587 StringAttr name, Type functionType,
2588 ArrayAttr argAttrs, DictionaryAttr metadata) {
2590 return emitError() << "the kernel name can't be empty
";
2592 if (llvm::any_of(argAttrs, [](Attribute attr) {
2593 return !llvm::isa<DictionaryAttr>(attr);
2596 << "all attributes in the array must be a
dictionary attribute
";
2601 //===----------------------------------------------------------------------===//
2602 // GPU KernelTableAttr
2603 //===----------------------------------------------------------------------===//
2605 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2606 ArrayRef<KernelMetadataAttr> kernels,
2608 // Note that `is_sorted` is always only invoked once even with assertions ON.
2609 assert((!isSorted || llvm::is_sorted(kernels)) &&
2610 "expected a sorted kernel array
");
2611 // Immediately return the attribute if the array is sorted.
2612 if (isSorted || llvm::is_sorted(kernels))
2613 return Base::get(context, kernels);
2615 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2616 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2617 return Base::get(context, kernelsTmp);
2620 KernelTableAttr KernelTableAttr::getChecked(
2621 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2622 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2623 // Note that `is_sorted` is always only invoked once even with assertions ON.
2624 assert((!isSorted || llvm::is_sorted(kernels)) &&
2625 "expected a sorted kernel array
");
2626 // Immediately return the attribute if the array is sorted.
2627 if (isSorted || llvm::is_sorted(kernels))
2628 return Base::getChecked(emitError, context, kernels);
2630 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2631 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2632 return Base::getChecked(emitError, context, kernelsTmp);
2636 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2637 ArrayRef<KernelMetadataAttr> kernels) {
2638 if (kernels.size() < 2)
2640 // Check that the kernels are uniquely named.
2641 if (std::adjacent_find(kernels.begin(), kernels.end(),
2642 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2643 return l.getName() == r.getName();
2644 }) != kernels.end()) {
2645 return emitError() << "expected all kernels to be uniquely named
";
2650 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2651 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2652 return found ? *iterator : KernelMetadataAttr();
2655 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2656 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2657 return found ? *iterator : KernelMetadataAttr();
2660 //===----------------------------------------------------------------------===//
2661 // GPU target options
2662 //===----------------------------------------------------------------------===//
2664 TargetOptions::TargetOptions(
2665 StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2666 StringRef cmdOptions, StringRef elfSection,
2667 CompilationTarget compilationTarget,
2668 function_ref<SymbolTable *()> getSymbolTableCallback,
2669 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2670 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2671 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2672 function_ref<void(StringRef)> isaCallback)
2673 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2674 cmdOptions, elfSection, compilationTarget,
2675 getSymbolTableCallback, initialLlvmIRCallback,
2676 linkedLlvmIRCallback, optimizedLlvmIRCallback,
2679 TargetOptions::TargetOptions(
2680 TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2681 StringRef cmdOptions, StringRef elfSection,
2682 CompilationTarget compilationTarget,
2683 function_ref<SymbolTable *()> getSymbolTableCallback,
2684 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2685 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2686 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2687 function_ref<void(StringRef)> isaCallback)
2688 : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2689 cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2690 compilationTarget(compilationTarget),
2691 getSymbolTableCallback(getSymbolTableCallback),
2692 initialLlvmIRCallback(initialLlvmIRCallback),
2693 linkedLlvmIRCallback(linkedLlvmIRCallback),
2694 optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2695 isaCallback(isaCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }

StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }

ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
  return librariesToLink;
}

StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }

StringRef TargetOptions::getELFSection() const { return elfSection; }

SymbolTable *TargetOptions::getSymbolTable() const {
  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
  return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
  return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
  return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
  return isaCallback;
}

CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}

CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // For a correct tokenization of the command line options, `opts` must be
  // unquoted; otherwise the tokenization function returns a single string
  // (the unquoted `cmdOptions`), which is not the desired behavior.
  // Remove any quotes at the beginning and end of the string:
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}
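// Illustrative sketch (not part of this file): a quoted option string is
// unquoted before tokenization, so the flags come back as separate tokens.
// Assumes the trailing constructor parameters are defaulted.
//
//   gpu::TargetOptions opts(/*toolkitPath=*/"", /*librariesToLink=*/{},
//                           /*cmdOptions=*/"\"-O3 --fast-math\"");
//   auto [allocator, tokens] = opts.tokenizeCmdOptions();
//   // tokens == {"-O3", "--fast-math"}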
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};

  // Tokenize the substring that follows `startsWith` and drop it (and
  // everything after it) from the stored options.
  std::string tmp = cmdOptions.substr(startPos + startsWith.size());
  cmdOptions.resize(startPos);
  return tokenizeCmdOptions(tmp);
}
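// Illustrative sketch (not part of this file): splitting a trailing group of
// options out of the stored string. The "--device-opts=" marker is made up
// for the example; everything from the marker onward is tokenized and
// removed from `cmdOptions`.
//
//   gpu::TargetOptions opts(/*toolkitPath=*/"", /*librariesToLink=*/{},
//                           /*cmdOptions=*/"-g --device-opts=-O3 -v");
//   auto [alloc, tokens] =
//       opts.tokenizeAndRemoveSuffixCmdOptions("--device-opts=");
//   // tokens == {"-O3", "-v"};  opts.getCmdOptions() == "-g "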
MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)

#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"