#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

// ... accessors (getNumDims, getShape, getElementType, getOperand) elided in
// this listing ...

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();
    SmallVector<int64_t> shape;
    Type elementType;
    std::string operand;
    // ... parsing of `<` shape `x` element-type `,` operand-string `>`
    // elided in this listing ...
    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  // ... sparse handle keywords and the unknown-type error elided in this
  // listing ...
}
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // If the kernel has been converted to a non-GPU function op (e.g. during
    // separate compilation), skip the argument type checks below.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}
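// A minimal module accepted by this verifier looks like the following
// (an illustrative sketch, not taken from the upstream sources):
//
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel_1(%arg0: f32) kernel {
//         gpu.return
//       }
//     }
//     func.func @main(%arg0: f32) {
//       %c1 = arith.constant 1 : index
//       gpu.launch_func @kernels::@kernel_1
//           blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
//           args(%arg0 : f32)
//       return
//     }
//   }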
/// Parses an optional list of async operands with an optional leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}
/// Prints optional async dependencies with its leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}
/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}
/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }
  return success();
}
static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only ops in the entry block of the launch are trivially uniform.
  return op->getBlock() == &body.front();
}
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();
  return nullptr;
}
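// E.g. (illustrative, the exact assembly syntax may differ): a reduction over
// a cluster of size 1 combines a value only with itself, so
//   %0 = gpu.subgroup_reduce add %x cluster(size = 1) : (f32) -> f32
// folds directly to %x.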
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);

  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}
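// This prints the launch-header fragments of the custom form, e.g.
//   blocks(%bx, %by, %bz) in (%sz_bx = %0, %sz_by = %1, %sz_bz = %2)
// where the ids are region arguments and the sizes are bound to operands
// (SSA value names here are illustrative).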
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
954 result.
types.push_back(asyncTokenType);
956 bool hasCluster =
false;
961 regionArgs.resize(18);
970 regionArgsRef.slice(15, 3),
971 regionArgsRef.slice(12, 3)))
979 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
981 regionArgsRef.slice(6, 3),
982 regionArgsRef.slice(0, 3)) ||
983 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
985 regionArgsRef.slice(9, 3),
986 regionArgsRef.slice(3, 3)) ||
992 bool hasDynamicSharedMemorySize =
false;
994 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
995 hasDynamicSharedMemorySize =
true;
1011 LaunchOp::kNumConfigRegionAttributes + 6, index);
1014 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1016 arg.
ssaName = std::get<0>(ssaValueAndType);
1017 arg.
type = std::get<1>(ssaValueAndType);
1018 regionArguments.push_back(arg);
1029 unsigned numWorkgroupAttrs = regionArguments.size() -
1030 LaunchOp::kNumConfigRegionAttributes -
1031 (hasCluster ? 6 : 0);
1032 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1049 segmentSizes.front() = asyncDependencies.size();
1052 segmentSizes[7] = 0;
1053 segmentSizes[8] = 0;
1054 segmentSizes[9] = 0;
1056 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1057 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
namespace {

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

} // end anonymous namespace

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  // Buffer attributions are inserted after the config region arguments and
  // any existing workgroup attributions.
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are inserted at the end of the argument list.
  return getBody().addArgument(type, loc);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
namespace {

/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
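// E.g. (illustrative): in
//   gpu.barrier
//   gpu.barrier
// the first barrier is erased by the pattern above; the threads are already
// synchronized by the remaining one.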
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are inserted at the end of the argument list.
  return getBody().addArgument(type, loc);
}
void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) {
                                 return arg.attrs != nullptr;
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}
/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}
static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Erase the entry by swapping it with the last and popping.
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    if (!value)
      return;
    elems.emplace_back(name, value);
  }
  if (mustSort)
    DictionaryAttr::sortInPlace(elems);
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}
LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}

void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets.begin(), targets.end());
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(builder.getNamedAttr(
      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}
static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  SMLoc loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalLess()) &&
      (parser.parseAttribute(offloadingHandler) || parser.parseGreater()))
    return parser.emitError(loc, "failed to parse the offloading handler");
  if (!offloadingHandler)
    offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}
//===----------------------------------------------------------------------===//
// GPU_MemcpyOp
//===----------------------------------------------------------------------===//

LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}
namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having Allocate memory effect in order to
    // perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // end anonymous namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
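// E.g. (illustrative): in
//   %alloc = memref.alloc() : memref<16xf32>
//   gpu.memcpy %alloc, %src : memref<16xf32>, memref<16xf32>
//   memref.dealloc %alloc : memref<16xf32>
// the destination is only ever written by the copy and then freed, so the
// pattern above erases the gpu.memcpy.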
//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// GPU_WaitOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
/// %t = gpu.wait async []       // No async dependencies.
/// ...  gpu.wait ... [%t, ...]  // %t can be removed.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

/// Simplify trivial gpu.wait ops for the following patterns.
/// 1. %t = gpu.wait async ... ops, where %t has no uses.
/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1
///    with %t0.
/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
///    dependencies nor return any token.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // end anonymous namespace

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
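// E.g. (illustrative): in
//   %t = gpu.wait async   // defines a token with no dependencies
//   gpu.wait [%t]
// EraseRedundantGpuWaitOpPairs drops %t from the second op's operand list,
// leaving `gpu.wait []`, which SimplifyGpuWaitOp then erases; finally the
// now-unused `%t = gpu.wait async` is erased as well.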
//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//

LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}

namespace {

/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
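// E.g. (illustrative, assuming the usual gpu.alloc assembly form):
//   %m = gpu.alloc(%size) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
// folds %d directly to %size.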
//===----------------------------------------------------------------------===//
// GPU object attribute
//===----------------------------------------------------------------------===//

LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}
namespace {
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
} // namespace
//===----------------------------------------------------------------------===//
// GPU select object attribute
//===----------------------------------------------------------------------===//

LogicalResult
gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              Attribute target) {
  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// DynamicSharedMemoryOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::DynamicSharedMemoryOp::verify() {
  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
    return emitOpError() << "must be inside an op with symbol table";

  MemRefType memrefType = getResultMemref().getType();
  // Check address space.
  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
    return emitOpError() << "address space must be "
                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
  }
  if (memrefType.hasStaticShape()) {
    return emitOpError() << "result memref type must be memref<?xi8, "
                            "#gpu.address_space<workgroup>>";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// GPU WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//

void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}
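// Printed form (illustrative): each lane passes in and receives a distributed
// slice, while the region works on the expanded warp-wide values:
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<1xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg0: vector<32xf32>):
//     ...
//     gpu.yield %r : vector<32xf32>
//   }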
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}
/// Helper check if the distributed vector type is consistent with the expanded
/// type and distributed size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (std::accumulate(scales.begin(), scales.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield =
      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//

KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  attrList.append(attrs);
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 attrList.getDictionary(getContext()));
}

LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  if (name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be a dictionary attribute";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// GPU KernelTableAttr
//===----------------------------------------------------------------------===//

KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // Note that `is_sorted` is always only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}

KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  // Note that `is_sorted` is always only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}

LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}

KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//

TargetOptions::TargetOptions(
    StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                    cmdOptions, elfSection, compilationTarget,
                    getSymbolTableCallback, initialLlvmIRCallback,
                    linkedLlvmIRCallback, optimizedLlvmIRCallback,
                    isaCallback) {}

TargetOptions::TargetOptions(
    TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
      cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
      compilationTarget(compilationTarget),
      getSymbolTableCallback(getSymbolTableCallback),
      initialLlvmIRCallback(initialLlvmIRCallback),
      linkedLlvmIRCallback(linkedLlvmIRCallback),
      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
      isaCallback(isaCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }

StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }

ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
  return librariesToLink;
}

StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }

StringRef TargetOptions::getELFSection() const { return elfSection; }

SymbolTable *TargetOptions::getSymbolTable() const {
  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
  return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
  return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
  return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
  return isaCallback;
}

CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}

CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // For a correct tokenization of the command line options `opts` must be
  // unquoted, otherwise the tokenization function returns a single string: the
  // unquoted `cmdOptions` -which is not the desired behavior.
  // Remove any quotes if they are at the beginning and end of the string:
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}
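// Usage sketch (illustrative, assuming a TargetOptions value `targetOptions`
// built with cmdOptions = "\"-O3 -v\""): the surrounding quotes are stripped
// first, so
//   auto [allocator, args] = targetOptions.tokenizeCmdOptions();
// yields args == {"-O3", "-v"}, with the BumpPtrAllocator owning the storage
// behind the returned C strings.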
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};

  auto tokenized =
      tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
  cmdOptions.resize(startPos);
  return tokenized;
}
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....