#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  // Keep only the mask bits strictly below the physical id; the popcount of
  // that prefix is the 0-based logical id among the active ids.
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}
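// Worked example (illustrative, not from the original source): with mask
// 0b1101 (physical ids 0, 2 and 3 active) and physicalLinearMappingId = 3,
// the filter is (1 << 3) - 1 = 0b0111, the masked prefix is
// 0b1101 & 0b0111 = 0b0101, and ctpop yields 2 -- physical id 3 is the third
// active id, i.e. logical id 2.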
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  // Set the bit at the physical id position and test it against the mask.
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
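// Worked example (illustrative, not from the original source): with mask
// 0b1101 and physicalLinearMappingId = 1, (1 << 1) & 0b1101 == 0, so the
// predicate is false -- physical id 1 is inactive under this mask.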
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable(
      "GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable(
      "GPUMemorySpaceMappingAttr does not support relative index");
}
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}
LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}
bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
}
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();
    SmallVector<int64_t> shape;
    Type elementType;
    StringRef operand;
    // ... (parsing of `<NxMxtype, "operand">` elided) ...
    return MMAMatrixType::getChecked(
        mlir::detail::getDefaultDiagnosticEmitFn(
            parser.getEncodedSourceLoc(beginLoc)),
        shape, elementType, operand);
  }
  // ... (sparse handle keywords elided) ...
  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError()
           << "expected '" << getContainerModuleAttrName()
           << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op, we are done.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // If the kernel function has been converted to a different function op
    // before verification, skip the argument checks below.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
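// Illustrative shape of the IR this verifier accepts (assumption, not from
// the original source):
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel(%arg0: f32) kernel { gpu.return }
//     }
//     // ... a gpu.launch_func nested exactly one level below the module:
//     gpu.launch_func @kernels::@kernel blocks in (%gx, %gy, %gz)
//         threads in (%bx, %by, %bz) args(%f : f32)
//   }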
/// Parses an optional list of async operands with an optional leading
/// keyword.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}
/// Prints optional async dependencies with its leading keyword.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }
  return success();
}
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in the gpu.launch entry block for now.
  return op->getBlock() == &body.front();
}
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();
  return nullptr;
}
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
  if (!sizeAttr)
    return; // Async dependencies is the only variadic operand.
  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a workgroup attribution count attribute. This attribute is required
  // to identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add op operands: async dependencies first, then the launch configuration.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  // ... (optional cluster size operands elided) ...
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes index-typed
  // arguments, followed by the attribution arguments.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add workgroup & private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);

  // Fill the operand segment size attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}
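// Summary of the kernel region argument layout used by the accessors above
// (derived from the indices; for orientation only):
//   args[0..2]   block ids      args[3..5]   thread ids
//   args[6..8]   grid size      args[9..11]  block size
//   args[12..14] cluster ids    args[15..17] cluster size (when present)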
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify attribution address spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
/// Pretty-print the kernel grid/block size assignment as
///   (%iter-x, %iter-y, %iter-z) in
///   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
/// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }

  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}
/// Parse the size assignment blocks for blocks and threads. These have the
/// form
///   (%region_arg, %region_arg, %region_arg) in
///   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 4> args;
  // ... (parse the parenthesized index list and the `in (` keyword) ...
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    // ... (parse `%region_arg = %operand`, comma-separated) ...
  }

  return parser.parseRParen();
}
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies,
                             parser.getBuilder().getType<AsyncTokenType>(),
                             result.operands))
    return failure();
  if (asyncTokenType)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three operand segments assign the cluster size; in the region
  // argument list these are the last six arguments.
  if (hasCluster &&
      parseSizeAssignment(parser, sizesRef.drop_front(6),
                          regionArgsRef.slice(15, 3),
                          regionArgsRef.slice(12, 3)))
    return failure();

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns
  // block sizes and defines values for thread identifiers.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Parse the optional dynamic shared memory size.
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    // ... (parse and resolve the size operand; elided) ...
  }

  StringRef moduleAttrName = getModuleAttrName(result.name);
  // ... (parse the optional module clause; elided) ...
  StringRef functionAttrName = getFunctionAttrName(result.name);
  // ... (parse the optional function clause; elided) ...

  // Create the region arguments: kNumConfigRegionAttributes index-typed
  // arguments for ids and sizes, plus six more when clusters are present.
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  // ... (parse workgroup and private attributions; elided) ...

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      parser.getBuilder().getI64IntegerAttr(numWorkgroupAttrs));

  // ... (parse the body region and the optional attribute dictionary;
  //      elided) ...

  // Add the segment sizes: async deps, grid, block, cluster, and dynamic
  // shared memory size.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if the size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};
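// Illustrative rewrite (assumption, not from the original source): given
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c4, %gy = %c4, %gz = %c1) ...
// the grid size in z is the constant 1, so every use of %bz inside the body
// is replaced with a single `arith.constant 0 : index`.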
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  // Workgroup attributions are inserted directly after the config arguments.
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are inserted at the end of the block argument list.
  return getBody().addArgument(type, loc);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    // If we have a type, parse it.
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}
/// Remove a gpu.barrier that is immediately followed by another gpu.barrier:
/// the threads are already synchronized.
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
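// Illustrative rewrite (assumption, not from the original source):
//   gpu.barrier
//   gpu.barrier
// canonicalizes to a single gpu.barrier; the first op is erased because its
// immediate successor already provides the synchronization.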
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are inserted at the end of the block argument list.
  return getBody().addArgument(type, loc);
}
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
/// Parses a GPU function memory attribution list for the given keyword and
/// collects per-attribution attribute dictionaries, if any.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) {
                                 return arg.attrs != nullptr;
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  FunctionType type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel keyword if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}
static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Erase the attribute by swapping it with the last element and
        // popping it off the back.
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    if (!value)
      return;
    elems.emplace_back(name, value);
  }
  if (mustSort)
    DictionaryAttr::sortInPlace(elems);
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return setAttributionAttr(*this, index, name, value,
                            getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return setAttributionAttr(*this, index, name, value,
                            getPrivateAttribAttrsAttrName());
}
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}
/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}
LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}
void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets.begin(), targets.end());
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}

LogicalResult GPUModuleOp::verify() {
  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
  if (!targets)
    return success();

  for (auto target : targets) {
    if (auto verifyTargetAttr =
            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
        return failure();
    }
  }
  return success();
}
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(builder.getNamedAttr(
      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}
static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(offloadingHandler) || parser.parseGreater())
      return failure();
  }
  if (!offloadingHandler)
    offloadingHandler =
        parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}
//===----------------------------------------------------------------------===//
// GPU_MemcpyOp
//===----------------------------------------------------------------------===//

LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}
namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having Allocate memory effect in order to
    // perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // end anonymous namespace
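// Illustrative IR shape the pattern above erases (assumption, not from the
// original source):
//   %dst = gpu.alloc() : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>
// The destination is never read, so the copy is dead; the now-unused alloc
// and dealloc can then be removed by other cleanups.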
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
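// Worked example of the matmul constraint (illustrative, not from the
// original source): for C += A * B with
//   A : !gpu.mma_matrix<16x8xf16, "AOp">  and
//   B : !gpu.mma_matrix<8x32xf16, "BOp">,
// C must be 16x32: aShape[1] == bShape[0] (8), aShape[0] == cShape[0] (16),
// and bShape[1] == cShape[1] (32).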
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// GPU_WaitOp
//===----------------------------------------------------------------------===//
namespace {

/// Remove gpu.wait ops that only take other gpu.wait ops with no operands as
/// dependencies.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};
/// Simplify trivial gpu.wait ops.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // end anonymous namespace
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//

LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}
namespace {

/// Folds memref.dim(gpu.alloc(%size), %idx) -> %size.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace
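// Illustrative fold (assumption, not from the original source):
//   %m = gpu.alloc(%n) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
// canonicalizes so that all uses of %d become %n, since the queried dynamic
// dimension is exactly the size operand passed to the alloc.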
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
//===----------------------------------------------------------------------===//
// GPU object attribute
//===----------------------------------------------------------------------===//

LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}
namespace {
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}
void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
} // namespace
//===----------------------------------------------------------------------===//
// GPU select object attribute
//===----------------------------------------------------------------------===//

LogicalResult
gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              Attribute target) {
  // Check `target`: it can be null, an integer attr, or a GPU Target attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
2299 // DynamicSharedMemoryOp
2300 //===----------------------------------------------------------------------===//
2302 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2303 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2304 return emitOpError() << "must be inside an op with symbol table";
2306 MemRefType memrefType = getResultMemref().getType();
2307 // Check address space
2308 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2309 return emitOpError() << "address space must be "
2310 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2311 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2313 if (memrefType.hasStaticShape()) {
2314 return emitOpError() << "result memref type must be memref<?xi8, "
2315 "#gpu.address_space<workgroup>>";
//===----------------------------------------------------------------------===//
// GPU WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//

void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{}))
    return failure();
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/{}, /*argTypes=*/{});
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
/// Checks that a distributed type is consistent with its expanded counterpart:
/// the per-dimension scaling factors must multiply to the warp size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (llvm::product_of(scales) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  gpu::YieldOp yield = getTerminator();
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
//===----------------------------------------------------------------------===//
// GPU_SubgroupBroadcastOp
//===----------------------------------------------------------------------===//

void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}

Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    // Cannot speculate a first_active_lane broadcast, because speculating it
    // across control flow can change the set of active lanes.
    return Speculation::NotSpeculatable;
  case BroadcastType::specific_lane:
    // Speculation is safe as long as we stay inside structured control flow.
    return Speculation::Speculatable;
  }
  llvm_unreachable("unknown broadcast type");
}
LogicalResult gpu::SubgroupBroadcastOp::verify() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    if (getLane())
      return emitOpError()
             << "lane can only be specified for `specific_lane` broadcast";
    return success();
  case BroadcastType::specific_lane:
    if (!getLane())
      return emitOpError()
             << "lane must be specified for `specific_lane` broadcast";
    return success();
  }
  llvm_unreachable("unknown broadcast type");
}
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//

KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}
KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  attrList.append(attrs);
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 attrList.getDictionary(getContext()));
}
LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  if (name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be a dictionary attribute";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// GPU KernelTableAttr
//===----------------------------------------------------------------------===//

KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // Note that `is_sorted` is always only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}
KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  // Note that `is_sorted` is always only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}
LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}
KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//

TargetOptions::TargetOptions(
    StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                    cmdOptions, elfSection, compilationTarget,
                    getSymbolTableCallback, initialLlvmIRCallback,
                    linkedLlvmIRCallback, optimizedLlvmIRCallback,
                    isaCallback) {}

TargetOptions::TargetOptions(
    TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
      cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
      compilationTarget(compilationTarget),
      getSymbolTableCallback(getSymbolTableCallback),
      initialLlvmIRCallback(initialLlvmIRCallback),
      linkedLlvmIRCallback(linkedLlvmIRCallback),
      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
      isaCallback(isaCallback), typeID(typeID) {}
2683 TypeID TargetOptions::getTypeID() const { return typeID; }
2685 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2687 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2688 return librariesToLink;
2691 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2693 StringRef TargetOptions::getELFSection() const { return elfSection; }
2695 SymbolTable *TargetOptions::getSymbolTable() const {
2696 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2699 function_ref<void(llvm::Module &)>
2700 TargetOptions::getInitialLlvmIRCallback() const {
2701   return initialLlvmIRCallback;
2702 }
2704 function_ref<void(llvm::Module &)>
2705 TargetOptions::getLinkedLlvmIRCallback() const {
2706   return linkedLlvmIRCallback;
2707 }
2709 function_ref<void(llvm::Module &)>
2710 TargetOptions::getOptimizedLlvmIRCallback() const {
2711   return optimizedLlvmIRCallback;
2712 }
2714 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2715   return isaCallback;
2716 }
2718 CompilationTarget TargetOptions::getCompilationTarget() const {
2719   return compilationTarget;
2720 }
2722 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2723   return CompilationTarget::Fatbin;
2724 }
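A minimal usage sketch of the getters above (the values are hypothetical, and this assumes the defaulted constructor parameters that TargetOptions declares in its header):

    #include <cassert>

    mlir::gpu::TargetOptions options(/*toolkitPath=*/"/opt/cuda",
                                     /*librariesToLink=*/{},
                                     /*cmdOptions=*/"-v");
    // Unset callbacks degrade gracefully: no symbol-table callback was
    // registered, so getSymbolTable() returns nullptr.
    assert(options.getSymbolTable() == nullptr);
    // Fatbin is the default compilation target when none is given.
    assert(options.getCompilationTarget() ==
           mlir::gpu::CompilationTarget::Fatbin);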
2726 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2727 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2728 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2729 llvm::StringSaver stringSaver(options.first);
2730 StringRef opts = cmdOptions;
2731   // For a correct tokenization of the command line options, `opts` must be
2732   // unquoted; otherwise the tokenizer returns a single token (the whole
2733   // quoted `cmdOptions` string), which is not the desired behavior.
2734   // Remove the quotes if they enclose the whole string:
2735 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2736     opts.consume_front("\""), opts.consume_back("\"");
2737   if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2738     opts.consume_front("'"), opts.consume_back("'");
2739 #ifdef _WIN32
2740   llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2741                                        /*MarkEOLs=*/false);
2742 #else
2743   llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2744                                    /*MarkEOLs=*/false);
2745 #endif // _WIN32
2746   return options;
2747 }
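The quoting pitfall the function guards against can be seen directly with the LLVM tokenizer (a sketch with hypothetical option strings; the required headers, llvm/Support/CommandLine.h and llvm/Support/StringSaver.h, are already included by this file):

    llvm::BumpPtrAllocator alloc;
    llvm::StringSaver saver(alloc);
    llvm::SmallVector<const char *> argv;
    // Quoted as a whole, the string survives as one token...
    llvm::cl::TokenizeGNUCommandLine("\"-O3 -v\"", saver, argv);
    // argv == {"-O3 -v"}
    argv.clear();
    // ...while the unquoted form splits into separate options.
    llvm::cl::TokenizeGNUCommandLine("-O3 -v", saver, argv);
    // argv == {"-O3", "-v"}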
2749 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2750 TargetOptions::tokenizeCmdOptions() const {
2751   return tokenizeCmdOptions(cmdOptions);
2752 }
2754 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2755 TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2756   size_t startPos = cmdOptions.find(startsWith);
2757   if (startPos == std::string::npos)
2758     return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2768 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2769 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2771 #define GET_ATTRDEF_CLASSES
2772 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2774 #define GET_OP_CLASSES
2775 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2777 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"