33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/ErrorHandling.h"
37 #include "llvm/Support/StringSaver.h"
44 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
// --- Device mapping attribute implementations -----------------------------
// Each GPU*MappingAttr implements the DeviceMappingAttrInterface trio:
// getMappingId / isLinearMapping / getRelativeIndex. A mapping is "linear"
// when its id is at or beyond MappingId::LinearDim0; the relative index is
// the id with that base subtracted (truncated here — tail of the ternary is
// outside this chunk).
// GPUBlockMappingAttr: id derives from the mapped grid block dimension.
50 int64_t GPUBlockMappingAttr::getMappingId()
const {
51 return static_cast<int64_t
>(getBlock());
54 bool GPUBlockMappingAttr::isLinearMapping()
const {
55 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
58 int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
59 return isLinearMapping()
60 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
// GPUWarpgroupMappingAttr: identical pattern, id from the warpgroup dim.
64 int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
65 return static_cast<int64_t
>(getWarpgroup());
68 bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
69 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
72 int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
73 return isLinearMapping()
74 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
// GPUWarpMappingAttr: identical pattern, id from the warp dim.
78 int64_t GPUWarpMappingAttr::getMappingId()
const {
79 return static_cast<int64_t
>(getWarp());
82 bool GPUWarpMappingAttr::isLinearMapping()
const {
83 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
86 int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
87 return isLinearMapping()
88 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
// GPUThreadMappingAttr: identical pattern, id from the thread dim.
92 int64_t GPUThreadMappingAttr::getMappingId()
const {
93 return static_cast<int64_t
>(getThread());
96 bool GPUThreadMappingAttr::isLinearMapping()
const {
97 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
100 int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
101 return isLinearMapping()
102 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
// GPUMemorySpaceMappingAttr: id is the address space. Linear mapping and
// relative index are meaningless for memory spaces and deliberately trap.
106 int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
107 return static_cast<int64_t
>(getAddressSpace());
110 bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
111 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
114 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
115 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
// --- MMAMatrixType fragments ----------------------------------------------
// Pieces of the MMAMatrixType factory and verifier. The element-type check
// accepts (at least) F16/F32; the operand string must be one of "AOp",
// "BOp", "COp"; the shape must be exactly rank 2.
132 elementType, operand);
146 return elementType.
isF16() || elementType.
isF32() ||
// Verifier: reject any operand tag other than the three MMA fragments.
155 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
156 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
// MMA matrices are strictly 2-D.
158 if (shape.size() != 2)
159 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
163 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
// --- GPUDialect static helpers and initialization -------------------------
// True iff `memorySpace` is a gpu::AddressSpaceAttr holding the workgroup
// address space (non-gpu attrs fall through; tail outside this chunk).
172 bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
175 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
176 return gpuAttr.getValue() == getWorkgroupAddressSpace();
// Convenience overload: inspect the memref's memory-space attribute.
180 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
181 Attribute memorySpace = type.getMemorySpace();
182 return isWorkgroupMemoryAddressSpace(memorySpace);
// An op is a kernel iff it carries the unit `gpu.kernel` attribute.
185 bool GPUDialect::isKernel(
Operation *op) {
186 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
187 return static_cast<bool>(isKernelAttr);
// Dialect registration: types, generated ops/attributes, inliner interface,
// and a promised bufferization interface (declaration truncated here).
203 void GPUDialect::initialize() {
204 addTypes<AsyncTokenType>();
205 addTypes<MMAMatrixType>();
206 addTypes<SparseDnTensorHandleType>();
207 addTypes<SparseSpMatHandleType>();
208 addTypes<SparseSpGEMMOpHandleType>();
211 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
214 #define GET_ATTRDEF_LIST
215 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
217 addInterfaces<GPUInlinerInterface>();
218 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
// Keyword spelling for each sparse handle kind (switch body; the switch
// header is outside this chunk).
225 return "sparse.dntensor_handle";
227 return "sparse.spmat_handle";
229 return "sparse.spgemmop_handle";
231 llvm_unreachable(
"unknown sparse handle kind");
// --- Dialect type parsing / printing fragments ----------------------------
// parseType: dispatch on the leading keyword ("async.token", "mma_matrix").
243 if (keyword ==
"async.token")
246 if (keyword ==
"mma_matrix") {
275 shape, elementType, operand);
// printType: TypeSwitch over the gpu types; MMAMatrixType prints
// `!gpu.mma_matrix<MxNxT, "operand">`; unknown kinds are unreachable.
293 .Case<SparseDnTensorHandleType>([&](
Type) {
296 .Case<SparseSpMatHandleType>(
298 .Case<SparseSpGEMMOpHandleType>([&](
Type) {
304 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
307 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
309 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
// Known-launch-size attribute check: must be a DenseI32ArrayAttr with
// exactly 3 elements (x, y, z).
314 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
317 " must be a dense i32 array");
318 if (array.size() != 3)
320 " must contain exactly 3 elements");
// --- GPUDialect::verifyOperationAttribute ---------------------------------
// Validates gpu dialect attributes attached to arbitrary ops. Known
// block/grid-size helper attrs are handled first; otherwise only the unit
// `gpu.container_module` attr is accepted, and it must sit on a ModuleOp.
324 LogicalResult GPUDialect::verifyOperationAttribute(
Operation *op,
326 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
328 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
330 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
331 attr.
getName() != getContainerModuleAttrName())
334 auto module = dyn_cast<ModuleOp>(op);
337 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
338 << ModuleOp::getOperationName() <<
'\'';
// Walk every launch_func in the container module and validate it against
// the kernel it references.
340 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
// Ignore launch ops not nested directly under this module.
343 if (!launchOp->getParentOp() ||
344 launchOp->getParentOp()->getParentOp() != module)
// Ignore launch ops with a missing kernel symbol (reported elsewhere).
349 if (!launchOp->getAttrOfType<SymbolRefAttr>(
350 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
// The kernel container symbol must resolve...
354 StringAttr kernelContainerName = launchOp.getKernelModuleName();
355 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
356 if (!kernelContainer)
358 <<
"kernel container '" << kernelContainerName.getValue()
// ...and be either a gpu.binary or a gpu.module.
362 if (isa<BinaryOp>(kernelContainer))
365 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
367 return launchOp.emitOpError()
368 <<
"kernel module '" << kernelContainerName.getValue()
// The referenced kernel must exist and be a function-like op.
372 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
375 << launchOp.getKernel() <<
"' is undefined";
376 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
377 if (!kernelConvertedFunction) {
379 <<
"referenced kernel '" << launchOp.getKernel()
380 <<
"' is not a function";
381 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
// The kernel function must carry the `gpu.kernel` unit attribute.
386 GPUDialect::getKernelFuncAttrName()))
387 return launchOp.emitOpError(
"kernel function is missing the '")
388 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
// For gpu.func kernels, operand count and types must match the signature.
393 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
394 if (!kernelGPUFunction)
397 unsigned actualNumArguments = launchOp.getNumKernelOperands();
398 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
399 if (expectedNumArguments != actualNumArguments)
400 return launchOp.emitOpError(
"got ")
401 << actualNumArguments <<
" kernel operands but expected "
402 << expectedNumArguments;
404 auto functionType = kernelGPUFunction.getFunctionType();
405 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
406 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
407 return launchOp.emitOpError(
"type of function argument ")
408 << i <<
" does not match";
// Any interrupted walk means a verification failure was emitted.
415 return walkResult.wasInterrupted() ? failure() : success();
// --- Async-dependency parse/print + attribution verification --------------
// Parsing: an op marked 'async' must bind its token to a name.
428 return parser.
emitError(loc,
"needs to be named when marked 'async'");
// Printing: elide the dependency list when empty, else comma-separate.
443 if (asyncDependencies.empty())
448 llvm::interleaveComma(asyncDependencies, printer);
// Attribution printing helper: `keyword(v1 : t1, ...)`.
477 p <<
' ' << keyword <<
'(';
478 llvm::interleaveComma(
// Attribution verification: every attribution value must be a memref in
// the expected gpu address space.
486 gpu::AddressSpace memorySpace) {
487 for (
Value v : attributions) {
488 auto type = llvm::dyn_cast<MemRefType>(v.getType());
490 return op->
emitOpError() <<
"expected memref type in attribution";
495 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
498 if (addressSpace.getValue() != memorySpace)
500 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
501 <<
" in attribution";
// --- All-reduce op/type compatibility + AllReduceOp::verifyRegions --------
// Float-min/max reductions require a float result; integer/bitwise
// reductions require an integer result.
512 using Kind = gpu::AllReduceOperation;
513 if (llvm::is_contained(
514 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
516 if (!isa<FloatType>(resType))
520 if (llvm::is_contained({Kind::MINSI,
Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
521 Kind::AND, Kind::OR, Kind::XOR},
523 if (!isa<IntegerType>(resType))
// verifyRegions: exactly one of {op attribute, non-empty body} must be
// present. A body takes two arguments of the result type and must yield a
// single value of that type.
530 LogicalResult gpu::AllReduceOp::verifyRegions() {
531 if (getBody().empty() != getOp().has_value())
532 return emitError(
"expected either an op attribute or a non-empty body");
533 if (!getBody().empty()) {
534 if (getBody().getNumArguments() != 2)
535 return emitError(
"expected two region arguments");
536 for (
auto argument : getBody().getArguments()) {
537 if (argument.getType() !=
getType())
538 return emitError(
"incorrect region argument type");
// Count gpu.yield terminators; each must yield one value of the op type.
540 unsigned yieldCount = 0;
541 for (
Block &block : getBody()) {
542 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
543 if (yield.getNumOperands() != 1)
544 return emitError(
"expected one gpu.yield operand");
545 if (yield.getOperand(0).getType() !=
getType())
546 return emitError(
"incorrect gpu.yield type");
551 return emitError(
"expected gpu.yield op in region");
// Attribute form: the named reduction must be compatible with the type.
553 gpu::AllReduceOperation opName = *getOp();
555 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
556 <<
"` reduction operation is not compatible with type "
// --- Uniformity check, enum attr parse/print, SubgroupReduceOp ------------
// Uniformity analysis fragment: inspects the enclosing gpu.launch body.
565 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
569 Region &body = launchOp.getBody();
570 assert(!body.
empty() &&
"Invalid region");
// Custom parse/print of the AllReduceOperation enum attribute.
587 AllReduceOperationAttr &attr) {
590 std::optional<AllReduceOperation> op =
591 gpu::symbolizeAllReduceOperation(enumStr);
600 AllReduceOperationAttr attr) {
// SubgroupReduceOp::verify: scalable vectors are rejected; for vectors the
// element type is what must be compatible with the reduction kind.
611 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
612 if (vecTy.isScalable())
613 return emitOpError() <<
"is not compatible with scalable vector types";
615 elemType = vecTy.getElementType();
618 gpu::AllReduceOperation opName = getOp();
620 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
621 <<
"` reduction operation is not compatible with type "
// Cluster size, when present, must be a power of two.
625 auto clusterSize = getClusterSize();
627 uint32_t size = *clusterSize;
628 if (!llvm::isPowerOf2_32(size)) {
629 return emitOpError() <<
"cluster size " << size
630 <<
" is not a power of two";
// Cluster stride requires a cluster size and must be a power of two.
634 uint32_t stride = getClusterStride();
635 if (stride != 1 && !clusterSize) {
636 return emitOpError() <<
"cluster stride can only be specified if cluster "
639 if (!llvm::isPowerOf2_32(stride)) {
640 return emitOpError() <<
"cluster stride " << stride
641 <<
" is not a power of two";
// fold: a reduction over a single-lane cluster is the identity.
647 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
648 if (getClusterSize() == 1)
// --- Operand-segment helper + LaunchOp::build -----------------------------
// Helper: only ops with AttrSizedOperandSegments carry the segment attr.
665 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
669 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
// LaunchOp::build: records the workgroup-attribution count, adds the six
// grid/block size operands, optional dynamic shared memory size, and the
// config region arguments plus attributions; finally fills in the
// operand-segment sizes (slots 7-9 are the optional cluster dims).
687 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
696 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
705 result.
addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
706 getBlockSizeY, getBlockSizeZ});
713 if (dynamicSharedMemorySize)
722 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
725 for (
Type argTy : workgroupAttributions)
727 for (
Type argTy : privateAttributions)
731 segmentSizes.front() = asyncDependencies.size();
732 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
733 segmentSizes[7] = clusterSizeX ? 1 : 0;
734 segmentSizes[8] = clusterSizeY ? 1 : 0;
735 segmentSizes[9] = clusterSizeZ ? 1 : 0;
// --- LaunchOp region-argument and operand accessors -----------------------
// The body region's leading arguments encode, in order, the config values:
// ids and sizes for blocks/threads (args 0-11) and, when clusters are
// present, cluster ids (12-14) and cluster sizes (15-17).
741 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
742 auto args = getBody().getArguments();
747 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
748 auto args = getBody().getArguments();
753 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
754 auto args = getBody().getArguments();
759 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
760 auto args = getBody().getArguments();
761 return KernelDim3{args[9], args[10], args[11]};
// Cluster ids/sizes are only present when the op has a cluster size.
764 std::optional<KernelDim3> LaunchOp::getClusterIds() {
765 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
766 if (!hasClusterSize())
768 auto args = getBody().getArguments();
769 return KernelDim3{args[12], args[13], args[14]};
772 std::optional<KernelDim3> LaunchOp::getClusterSize() {
773 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
774 if (!hasClusterSize())
776 auto args = getBody().getArguments();
777 return KernelDim3{args[15], args[16], args[17]};
// Operand views: skip async dependencies, then grid (0-2), block (3-5)
// and optional cluster (6-8) size operands.
780 KernelDim3 LaunchOp::getGridSizeOperandValues() {
781 auto operands = getOperands().drop_front(getAsyncDependencies().size());
782 return KernelDim3{operands[0], operands[1], operands[2]};
785 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
786 auto operands = getOperands().drop_front(getAsyncDependencies().size());
787 return KernelDim3{operands[3], operands[4], operands[5]};
790 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
791 auto operands = getOperands().drop_front(getAsyncDependencies().size());
792 if (!hasClusterSize())
794 return KernelDim3{operands[6], operands[7], operands[8]};
// verify fragment: cluster dims must be all present or all absent.
798 if (!(hasClusterSize()) &&
799 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
800 return emitOpError() <<
"cluster size must be all present";
// --- LaunchOp::verifyRegions ----------------------------------------------
// The body must carry at least the config region arguments plus the
// declared workgroup attributions; attributions must live in the proper
// gpu address spaces; every block must end in gpu.terminator or a
// terminator with successors; an async token requires the op to be named.
804 LogicalResult LaunchOp::verifyRegions() {
808 if (!getBody().empty()) {
809 if (getBody().getNumArguments() <
810 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
811 return emitOpError(
"unexpected number of region arguments");
816 GPUDialect::getWorkgroupAddressSpace())) ||
818 GPUDialect::getPrivateAddressSpace())))
823 for (
Block &block : getBody()) {
826 if (block.back().getNumSuccessors() != 0)
828 if (!isa<gpu::TerminatorOp>(&block.back())) {
831 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
832 "' or a terminator with successors")
833 .attachNote(getLoc())
834 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
838 if (getNumResults() == 0 && getAsyncToken())
839 return emitOpError(
"needs to be named when async keyword is specified");
// --- printSizeAssignment + LaunchOp::print --------------------------------
// Prints `(idX, idY, idZ) in (szX = opX, szY = opY, szZ = opZ)`.
850 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
851 p << size.
x <<
" = " << operands.
x <<
", ";
852 p << size.
y <<
" = " << operands.
y <<
", ";
853 p << size.
z <<
" = " << operands.
z <<
')';
// LaunchOp::print: optional async token + dependencies, then optional
// clusters, then blocks/threads size assignments, then the optional
// dynamic shared memory size, then attributes (eliding internal ones).
857 if (getAsyncToken()) {
859 if (!getAsyncDependencies().empty())
860 p <<
" [" << getAsyncDependencies() <<
']';
863 if (hasClusterSize()) {
864 p <<
' ' << getClustersKeyword();
866 getClusterSizeOperandValues().value(),
867 getClusterIds().value());
869 p <<
' ' << getBlocksKeyword();
872 p <<
' ' << getThreadsKeyword();
875 if (getDynamicSharedMemorySize())
876 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
877 << getDynamicSharedMemorySize();
886 LaunchOp::getOperandSegmentSizeAttr(),
887 getNumWorkgroupAttributionsAttrName()});
// --- LaunchOp parsing -----------------------------------------------------
// parseSizeAssignment fragment: fills exactly three (size, id) pairs.
901 assert(indices.size() == 3 &&
"space for three indices expected");
907 std::move(args.begin(), args.end(), indices.begin());
909 for (
int i = 0; i < 3; ++i) {
// LaunchOp::parse: region args are laid out as sizes then ids; optional
// `clusters` keyword extends the region args to 18.
931 sizes(LaunchOp::kNumConfigOperands);
938 LaunchOp::kNumConfigRegionAttributes);
949 result.
types.push_back(asyncTokenType);
951 bool hasCluster =
false;
956 regionArgs.resize(18);
// Cluster sizes occupy region args 15-17, cluster ids 12-14.
965 regionArgsRef.slice(15, 3),
966 regionArgsRef.slice(12, 3)))
// Blocks: sizes at 6-8, ids at 0-2. Threads: sizes at 9-11, ids at 3-5.
974 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
976 regionArgsRef.slice(6, 3),
977 regionArgsRef.slice(0, 3)) ||
978 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
980 regionArgsRef.slice(9, 3),
981 regionArgsRef.slice(3, 3)) ||
// Optional dynamic shared memory size operand.
987 bool hasDynamicSharedMemorySize =
false;
989 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
990 hasDynamicSharedMemorySize =
true;
// Build the region argument list (ssa name + type pairs).
1006 LaunchOp::kNumConfigRegionAttributes + 6, index);
1009 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1011 arg.
ssaName = std::get<0>(ssaValueAndType);
1012 arg.
type = std::get<1>(ssaValueAndType);
1013 regionArguments.push_back(arg);
// Workgroup attribution count = extra region args beyond the config ones
// (clusters add 6 more config args when present).
1024 unsigned numWorkgroupAttrs = regionArguments.size() -
1025 LaunchOp::kNumConfigRegionAttributes -
1026 (hasCluster ? 6 : 0);
1027 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
// Segment sizes: async deps first; cluster slots 7-9 zeroed here
// (NOTE(review): presumably overwritten when hasCluster — tail not
// visible in this chunk); dynamic shared memory last.
1044 segmentSizes.front() = asyncDependencies.size();
1047 segmentSizes[7] = 0;
1048 segmentSizes[8] = 0;
1049 segmentSizes[9] = 0;
1051 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1052 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
// --- Launch canonicalization + attribution insertion ----------------------
// Folds block/thread ids to constant 0 when the corresponding dimension
// size makes them trivially known (e.g. size 1); unused ids are skipped.
1066 bool simplified =
false;
1067 auto constPropIdUses = [&](
Value id,
Value size) {
1071 if (
id.getUses().empty())
1078 rewriter.
create<arith::ConstantIndexOp>(op.getLoc(), 0);
1083 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1084 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1085 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1086 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1087 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1088 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1090 return success(simplified);
// addWorkgroupAttribution: bump the count attr and insert the new region
// argument right after the config args + existing workgroup attributions.
1102 auto attrName = getNumWorkgroupAttributionsAttrName();
1103 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1104 (*this)->setAttr(attrName,
1106 return getBody().insertArgument(
1107 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
// addPrivateAttribution: private attributions always go last.
1115 return getBody().addArgument(type, loc);
// --- LaunchFuncOp::build overloads ----------------------------------------
// Symbol-based build: kernel symbol must be module::function (one nested
// reference). Operand segments: [asyncDeps, grid(3), block(3),
// cluster(0|3), dynSharedMem(0|1), kernelOperands, asyncObject(0|1)].
1123 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1127 std::optional<KernelDim3> clusterSize) {
1128 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1129 "expected a symbol reference with a single nested reference");
1137 if (clusterSize.has_value())
1138 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1139 if (dynamicSharedMemorySize)
1144 prop.kernel = kernelSymbol;
1145 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1147 for (
auto &sz : prop.operandSegmentSizes)
1149 prop.operandSegmentSizes[0] = asyncDependencies.size();
1150 if (!clusterSize.has_value()) {
1151 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1152 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1153 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1155 prop.operandSegmentSizes[segmentSizesLen - 3] =
1156 dynamicSharedMemorySize ? 1 : 0;
1157 prop.operandSegmentSizes[segmentSizesLen - 2] =
1158 static_cast<int32_t
>(kernelOperands.size());
1159 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
// GPUFuncOp-based build: derive the nested kernel symbol from the
// function's enclosing gpu.module and delegate to the symbol build.
1167 std::optional<KernelDim3> clusterSize) {
1168 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1171 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1172 build(builder, result, kernelSymbol, gridSize,
getBlockSize,
1173 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1174 asyncDependencies, clusterSize);
// Async-object build: like the symbol build but with no async token deps
// and a trailing optional async object operand.
1181 std::optional<KernelDim3> clusterSize) {
1185 if (clusterSize.has_value())
1186 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1187 if (dynamicSharedMemorySize)
1193 prop.kernel = kernel;
1194 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1196 for (
auto &sz : prop.operandSegmentSizes)
1198 prop.operandSegmentSizes[0] = 0;
1199 if (!clusterSize.has_value()) {
1200 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1201 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1202 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1204 prop.operandSegmentSizes[segmentSizesLen - 3] =
1205 dynamicSharedMemorySize ? 1 : 0;
1206 prop.operandSegmentSizes[segmentSizesLen - 2] =
1207 static_cast<int32_t
>(kernelOperands.size());
1208 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
// --- LaunchFuncOp accessors and verifier ----------------------------------
563
1211 StringAttr LaunchFuncOp::getKernelModuleName() {
1215 StringAttr LaunchFuncOp::getKernelName() {
1219 unsigned LaunchFuncOp::getNumKernelOperands() {
1220 return getKernelOperands().size();
1223 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1224 return getKernelOperands()[i];
// Operand layout mirrors LaunchOp: grid at 0-2, block at 3-5, optional
// cluster at 6-8, all after the async dependencies.
1227 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1228 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1229 return KernelDim3{operands[0], operands[1], operands[2]};
1232 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1233 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1234 return KernelDim3{operands[3], operands[4], operands[5]};
1237 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1238 assert(hasClusterSize() &&
1239 "cluster size is not set, check hasClusterSize() first");
1240 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1241 return KernelDim3{operands[6], operands[7], operands[8]};
// verify: must live under a module carrying gpu.container_module; cluster
// dimension operands must all share one type.
1245 auto module = (*this)->getParentOfType<ModuleOp>();
1247 return emitOpError(
"expected to belong to a module");
1249 if (!module->getAttrOfType<UnitAttr>(
1250 GPUDialect::getContainerModuleAttrName()))
1251 return emitOpError(
"expected the closest surrounding module to have the '" +
1252 GPUDialect::getContainerModuleAttrName() +
1255 if (hasClusterSize()) {
1256 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1258 return emitOpError()
1259 <<
"expects types of the cluster dimensions must be the same";
// --- Cluster-dim type parse/print, arg list, ShuffleOp, BarrierOp ---------
// Parse helper: if a cluster operand was parsed, all three cluster dim
// types default to the shared dimension type.
1267 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1268 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1275 if (clusterValue.has_value()) {
1276 clusterXTy = clusterYTy = clusterZTy = dimTy;
// Print helper: emits a single `: type` for the shared dim type.
1283 Type clusterYTy,
Type clusterZTy) {
1285 printer <<
": " << dimTy;
// Kernel argument list parsing: comma list of `%name : type`.
1295 auto parseElement = [&]() -> ParseResult {
1296 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1301 parseElement,
" in argument list");
// Argument list printing: elide when empty, else `operand : type` pairs.
1306 if (operands.empty())
1309 llvm::interleaveComma(llvm::zip(operands, types), printer,
1310 [&](
const auto &pair) {
// ShuffleOp convenience build from plain ints + mode.
1323 int32_t offset, int32_t width, ShuffleMode mode) {
1324 build(builder, result, value,
// Canonicalization: two adjacent gpu.barrier ops are redundant — drop one.
1339 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1341 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1352 results.
add(eraseRedundantGpuBarrierOps);
// --- GPUFuncOp attribution insertion, build, parseAttributions ------------
// addWorkgroupAttribution: bump the count attr and insert the argument
// after the function inputs + existing workgroup attributions.
1362 auto attrName = getNumWorkgroupAttributionsAttrName();
1363 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1364 (*this)->setAttr(attrName,
1366 return getBody().insertArgument(
1367 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
// addPrivateAttribution: private attributions always go last.
1375 return getBody().addArgument(type, loc);
// build: record the attribution count and append entry-block arguments
// for inputs, workgroup attributions, then private attributions.
1379 StringRef name, FunctionType type,
1389 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
1396 for (
Type argTy : type.getInputs())
1398 for (
Type argTy : workgroupAttributions)
1400 for (
Type argTy : privateAttributions)
// parseAttributions: parse optional `keyword(args...)`; collect per-arg
// dictionaries into an ArrayAttr only when any arg actually had attrs.
1419 size_t existingArgs = args.size();
1420 ParseResult result =
1426 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1431 attributionAttrs =
nullptr;
1437 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1438 if (!argument.attrs)
1441 attributionAttrsVec.push_back(argument.attrs);
1443 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
// --- GPUFuncOp::parse and print -------------------------------------------
// parse: function signature (named args required), arg/result attrs,
// workgroup/private attributions, optional `kernel` keyword.
1459 StringAttr nameAttr;
1466 parser,
false, entryArgs, isVariadic, resultTypes,
1470 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1471 return parser.
emitError(signatureLocation)
1472 <<
"gpu.func requires named arguments";
1479 for (
auto &arg : entryArgs)
1480 argTypes.push_back(arg.
type);
1486 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
1487 getResAttrsAttrName(result.
name));
1492 entryArgs, workgroupAttributionAttrs)))
// Attribution count = args parsed beyond the declared inputs.
1497 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1498 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1500 if (workgroupAttributionAttrs)
1501 result.
addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.
name),
1502 workgroupAttributionAttrs);
1507 entryArgs, privateAttributionAttrs)))
1509 if (privateAttributionAttrs)
1511 privateAttributionAttrs);
// `kernel` keyword sets the gpu.kernel unit attribute.
1515 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
// printAttributions: `keyword(v : type {attrs}, ...)` with per-value
// dictionaries taken from the parallel ArrayAttr when present.
1530 ArrayAttr attributes) {
1534 p <<
' ' << keyword <<
'(';
1535 llvm::interleaveComma(
1538 p << v <<
" : " << v.
getType();
1540 size_t attributionIndex = pair.index();
1541 DictionaryAttr attrs;
1542 if (attributes && attributionIndex < attributes.size())
1543 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
// print: signature, attributions, optional `kernel`, then remaining
// attributes with the internally-managed ones elided.
1554 FunctionType type = getFunctionType();
1560 getWorkgroupAttribAttrs().value_or(
nullptr));
1562 getPrivateAttribAttrs().value_or(
nullptr));
1564 p <<
' ' << getKernelKeyword();
1568 {getNumWorkgroupAttributionsAttrName(),
1569 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1570 getArgAttrsAttrName(), getResAttrsAttrName(),
1571 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
// --- Attribution attribute getters/setters --------------------------------
// Generic getter: index into the op-level ArrayAttr of per-attribution
// dictionaries; empty DictionaryAttr when absent or out of range.
1577 StringAttr attrName) {
1578 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1579 if (!allAttrs || index >= allAttrs.size())
1580 return DictionaryAttr();
1581 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1584 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1588 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
// Generic setter: copy the existing array, pad with empty dictionaries up
// to `index`, overwrite the slot, and write the array back.
1593 DictionaryAttr value, StringAttr attrName) {
1595 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1598 elements.append(allAttrs.begin(), allAttrs.end());
1599 while (elements.size() <= index)
1604 elements[index] = value;
1606 op->setAttr(attrName, newValue);
1609 void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1610 DictionaryAttr value) {
1614 void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1615 DictionaryAttr value) {
// Named-attribute lookup inside one attribution's dictionary.
1620 StringAttr name, StringAttr attrsName) {
1624 return dict.get(name);
1627 Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1629 assert(index < getNumWorkgroupAttributions() &&
1630 "index must map to a workgroup attribution");
1632 getWorkgroupAttribAttrsAttrName());
1635 Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1637 assert(index < getNumPrivateAttributions() &&
1638 "index must map to a private attribution");
1640 getPrivateAttribAttrsAttrName());
// Named-attribute update: rebuild the dictionary, replacing or appending
// the entry; only re-sort when an in-place swap disturbed the order.
1644 Attribute value, StringAttr attrsName) {
1649 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1652 bool mustSort =
true;
1653 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1654 if (elems[i].getName() == name) {
1657 std::swap(elems[i], elems[elems.size() - 1]);
1669 elems.emplace_back(name, value);
1672 DictionaryAttr::sortInPlace(elems);
1674 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1678 void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1680 assert(index < getNumWorkgroupAttributions() &&
1681 "index must map to a workgroup attribution");
1683 getWorkgroupAttribAttrsAttrName());
1686 void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1688 assert(index < getNumPrivateAttributions() &&
1689 "index must map to a private attribution");
1691 getPrivateAttribAttrsAttrName());
// --- GPUFuncOp::verifyType / verifyBody -----------------------------------
// Kernels must return void.
1694 LogicalResult GPUFuncOp::verifyType() {
1695 if (isKernel() && getFunctionType().getNumResults() != 0)
1696 return emitOpError() <<
"expected void return type for kernel function";
// verifyBody: the entry block must carry at least one argument per input
// plus one per workgroup attribution, with input types matching; the
// attributions themselves must be in the proper address spaces.
1702 LogicalResult GPUFuncOp::verifyBody() {
1704 return emitOpError() <<
"expected body with at least one block";
1705 unsigned numFuncArguments = getNumArguments();
1706 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1707 unsigned numBlockArguments = front().getNumArguments();
1708 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1709 return emitOpError() <<
"expected at least "
1710 << numFuncArguments + numWorkgroupAttributions
1711 <<
" arguments to body region";
1714 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1715 Type blockArgType = front().getArgument(i).getType();
1716 if (funcArgTypes[i] != blockArgType)
1717 return emitOpError() <<
"expected body region argument #" << i
1718 <<
" to be of type " << funcArgTypes[i] <<
", got "
1723 GPUDialect::getWorkgroupAddressSpace())) ||
1725 GPUDialect::getPrivateAddressSpace())))
// --- gpu::ReturnOp::verify ------------------------------------------------
// Operand count and types must match the enclosing gpu.func result types;
// mismatches attach a note pointing at the function declaration.
1736 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1738 FunctionType funType =
function.getFunctionType();
1740 if (funType.getNumResults() != getOperands().size())
1741 return emitOpError()
1742 .append(
"expected ", funType.getNumResults(),
" result operands")
1743 .attachNote(
function.getLoc())
1744 .append(
"return type declared here");
// Pairwise type check of declared results against return operands.
1747 llvm::zip(
function.getFunctionType().getResults(), getOperands()))) {
1748 auto [type, operand] = pair.value();
1749 if (type != operand.getType())
1750 return emitOpError() <<
"unexpected type `" << operand.getType()
1751 <<
"' for operand #" << pair.index();
1761 StringRef name, ArrayAttr targets,
1766 props.targets = targets;
1768 props.offloadingHandler = offloadingHandler;
1774 build(builder, result, name,
1775 targets.empty() ? ArrayAttr() : builder.
getArrayAttr(targets),
1779 bool GPUModuleOp::hasTarget(
Attribute target) {
1780 if (ArrayAttr targets = getTargetsAttr())
1781 return llvm::count(targets.getValue(), target);
1786 ArrayAttr &targetsAttr = getProperties().targets;
1795 Attribute offloadingHandler, ArrayAttr objects) {
1799 properties.objects = objects;
1800 if (offloadingHandler)
1801 properties.offloadingHandler = offloadingHandler;
1803 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1808 build(builder, result, name, offloadingHandler,
1809 objects.empty() ? ArrayAttr() : builder.
getArrayAttr(objects));
1820 if (!offloadingHandler)
1828 printer << '<' << offloadingHandler << '>
';
1831 //===----------------------------------------------------------------------===//
1833 //===----------------------------------------------------------------------===//
1835 LogicalResult MemcpyOp::verify() {
1836 auto srcType = getSrc().getType();
1837 auto dstType = getDst().getType();
1839 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1840 return emitOpError("arguments have incompatible element type");
1842 if (failed(verifyCompatibleShape(srcType, dstType)))
1843 return emitOpError("arguments have incompatible shape");
1852 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1853 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1855 LogicalResult matchAndRewrite(MemcpyOp op,
1856 PatternRewriter &rewriter) const override {
1857 Value dest = op.getDst();
1858 Operation *destDefOp = dest.getDefiningOp();
1859 // `dest` must be defined by an op having Allocate memory effect in order to
1860 // perform the folding.
1862 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1864 // We can erase `op` iff `dest` has no other use apart from its
1865 // use by `op` and dealloc ops.
1866 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1867 return user != op &&
1868 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1871 // We can perform the folding if and only if op has a single async
1872 // dependency and produces an async token as result, or if it does not have
1873 // any async dependency and does not produce any async token result.
1874 if (op.getAsyncDependencies().size() > 1 ||
1875 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1876 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1878 rewriter.replaceOp(op, op.getAsyncDependencies());
1883 } // end anonymous namespace
1885 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1886 MLIRContext *context) {
1887 results.add<EraseTrivialCopyOp>(context);
1890 //===----------------------------------------------------------------------===//
1891 // GPU_SubgroupMmaLoadMatrixOp
1892 //===----------------------------------------------------------------------===//
1894 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1895 auto srcType = getSrcMemref().getType();
1896 auto resType = getRes().getType();
1897 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1898 auto operand = resMatrixType.getOperand();
1899 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1901 if (!isLastMemrefDimUnitStride(srcMemrefType))
1903 "expected source memref most minor dim must have unit stride");
1905 if (operand != "AOp" && operand != "BOp" && operand != "COp")
1906 return emitError("only AOp, BOp and COp can be loaded");
1911 //===----------------------------------------------------------------------===//
1912 // GPU_SubgroupMmaStoreMatrixOp
1913 //===----------------------------------------------------------------------===//
1915 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1916 auto srcType = getSrc().getType();
1917 auto dstType = getDstMemref().getType();
1918 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1919 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1921 if (!isLastMemrefDimUnitStride(dstMemrefType))
1923 "expected destination memref most minor dim must have unit stride");
1925 if (srcMatrixType.getOperand() != "COp")
1927 "expected the operand matrix being stored to have 'COp
' operand type");
1932 //===----------------------------------------------------------------------===//
1933 // GPU_SubgroupMmaComputeOp
1934 //===----------------------------------------------------------------------===//
1936 LogicalResult SubgroupMmaComputeOp::verify() {
1937 enum OperandMap { A, B, C };
1938 SmallVector<MMAMatrixType, 3> opTypes;
1939 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1940 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1941 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1943 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1944 opTypes[C].getOperand() != "COp")
1945 return emitError("operands must be in the order AOp, BOp, COp");
1947 ArrayRef<int64_t> aShape, bShape, cShape;
1948 aShape = opTypes[A].getShape();
1949 bShape = opTypes[B].getShape();
1950 cShape = opTypes[C].getShape();
1952 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1953 bShape[1] != cShape[1])
1954 return emitError("operand shapes do not satisfy matmul constraints");
1959 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1960 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1961 return memref::foldMemRefCast(*this);
1964 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1965 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1966 return memref::foldMemRefCast(*this);
1969 //===----------------------------------------------------------------------===//
1971 //===----------------------------------------------------------------------===//
1978 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1980 using OpRewritePattern::OpRewritePattern;
1982 LogicalResult matchAndRewrite(WaitOp op,
1983 PatternRewriter &rewriter) const final {
1984 auto predicate = [](Value value) {
1985 auto waitOp = value.getDefiningOp<WaitOp>();
1986 return waitOp && waitOp->getNumOperands() == 0;
1988 if (llvm::none_of(op.getAsyncDependencies(), predicate))
1990 SmallVector<Value> validOperands;
1991 for (Value operand : op->getOperands()) {
1992 if (predicate(operand))
1994 validOperands.push_back(operand);
1996 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2008 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2010 using OpRewritePattern::OpRewritePattern;
2012 LogicalResult matchAndRewrite(WaitOp op,
2013 PatternRewriter &rewriter) const final {
2014 // Erase gpu.wait ops that neither have any async dependencies nor return
2016 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2017 rewriter.eraseOp(op);
2020 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2021 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2022 op.getAsyncToken()) {
2023 rewriter.replaceOp(op, op.getAsyncDependencies());
2026 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2027 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2028 rewriter.eraseOp(op);
2035 } // end anonymous namespace
2037 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2038 MLIRContext *context) {
2039 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2042 //===----------------------------------------------------------------------===//
2044 //===----------------------------------------------------------------------===//
2046 LogicalResult AllocOp::verify() {
2047 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2049 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2050 return emitOpError("dimension operand count does not equal memref "
2051 "dynamic dimension count");
2053 unsigned numSymbols = 0;
2054 if (!memRefType.getLayout().isIdentity())
2055 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2056 if (getSymbolOperands().size() != numSymbols) {
2058 "symbol operand count does not equal memref symbol count");
2068 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2069 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2071 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2072 PatternRewriter &rewriter) const override {
2073 std::optional<int64_t> index = dimOp.getConstantIndex();
2077 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2078 if (!memrefType || index.value() >= memrefType.getRank() ||
2079 !memrefType.isDynamicDim(index.value()))
2082 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2086 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2087 memrefType.getDynamicDimIndex(index.value()));
2088 rewriter.replaceOp(dimOp, substituteOp);
2095 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2096 MLIRContext *context) {
2097 results.add<SimplifyDimOfAllocOp>(context);
2100 //===----------------------------------------------------------------------===//
2101 // GPU object attribute
2102 //===----------------------------------------------------------------------===//
2104 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2105 Attribute target, CompilationTarget format,
2106 StringAttr object, DictionaryAttr properties,
2107 KernelTableAttr kernels) {
2109 return emitError() << "the target attribute cannot be null";
2110 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2112 return emitError() << "the target attribute must implement or promise the "
2113 "`gpu::TargetAttrInterface`";
2117 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2118 StringAttr &object) {
2119 std::optional<CompilationTarget> formatResult;
2120 StringRef enumKeyword;
2121 auto loc = odsParser.getCurrentLocation();
2122 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2123 formatResult = CompilationTarget::Fatbin;
2124 if (!formatResult &&
2126 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2127 odsParser.parseEqual())
2128 return odsParser.emitError(loc, "expected an equal sign");
2130 return odsParser.emitError(loc, "expected keyword for GPU object format");
2131 FailureOr<StringAttr> objectResult =
2132 FieldParser<StringAttr>::parse(odsParser);
2133 if (failed(objectResult))
2134 return odsParser.emitError(odsParser.getCurrentLocation(),
2135 "failed to parse GPU_ObjectAttr parameter "
2136 "'
object' which is to be a `StringAttr`");
2137 format = *formatResult;
2138 object = *objectResult;
2142 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2143 StringAttr object) {
2144 if (format != CompilationTarget::Fatbin)
2145 odsParser << stringifyEnum(format) << " = ";
2146 odsParser << object;
2150 //===----------------------------------------------------------------------===//
2151 // GPU select object attribute
2152 //===----------------------------------------------------------------------===//
2155 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2157 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2159 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2160 if (intAttr.getInt() < 0) {
2161 return emitError() << "the object index must be positive";
2163 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2165 << "the target attribute must be a GPU Target attribute";
2171 //===----------------------------------------------------------------------===//
2172 // DynamicSharedMemoryOp
2173 //===----------------------------------------------------------------------===//
2175 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2176 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2177 return emitOpError() << "must be inside an op with symbol table";
2179 MemRefType memrefType = getResultMemref().getType();
2180 // Check address space
2181 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2182 return emitOpError() << "address space must be "
2183 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2184 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2186 if (memrefType.hasStaticShape()) {
2187 return emitOpError() << "result memref type must be memref<?xi8, "
2188 "#gpu.address_space<workgroup>>";
2193 //===----------------------------------------------------------------------===//
2194 // GPU WarpExecuteOnLane0Op
2195 //===----------------------------------------------------------------------===//
2197 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2198 p << "(" << getLaneid() << ")";
2200 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2201 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2202 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2204 if (!getArgs().empty())
2205 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2206 if (!getResults().empty())
2207 p << " -> (" << getResults().getTypes() << ')
';
2209 p.printRegion(getRegion(),
2210 /*printEntryBlockArgs=*/true,
2211 /*printBlockTerminators=*/!getResults().empty());
2212 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2215 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2216 OperationState &result) {
2217 // Create the region.
2218 result.regions.reserve(1);
2219 Region *warpRegion = result.addRegion();
2221 auto &builder = parser.getBuilder();
2222 OpAsmParser::UnresolvedOperand laneId;
2224 // Parse predicate operand.
2225 if (parser.parseLParen() ||
2226 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2227 parser.parseRParen())
2231 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2232 parser.parseRSquare())
2234 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2235 builder.getContext())),
2236 builder.getI64IntegerAttr(warpSize));
2238 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2241 llvm::SMLoc inputsOperandsLoc;
2242 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2243 SmallVector<Type> inputTypes;
2244 if (succeeded(parser.parseOptionalKeyword("args"))) {
2245 if (parser.parseLParen())
2248 inputsOperandsLoc = parser.getCurrentLocation();
2249 if (parser.parseOperandList(inputsOperands) ||
2250 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2253 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2257 // Parse optional results type list.
2258 if (parser.parseOptionalArrowTypeList(result.types))
2260 // Parse the region.
2261 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2264 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2266 // Parse the optional attribute list.
2267 if (parser.parseOptionalAttrDict(result.attributes))
2272 void WarpExecuteOnLane0Op::getSuccessorRegions(
2273 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2274 if (!point.isParent()) {
2275 regions.push_back(RegionSuccessor(getResults()));
2279 // The warp region is always executed
2280 regions.push_back(RegionSuccessor(&getWarpRegion()));
2283 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2284 TypeRange resultTypes, Value laneId,
2286 build(builder, result, resultTypes, laneId, warpSize,
2287 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2290 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2291 TypeRange resultTypes, Value laneId,
2292 int64_t warpSize, ValueRange args,
2293 TypeRange blockArgTypes) {
2294 result.addOperands(laneId);
2295 result.addAttribute(getAttributeNames()[0],
2296 builder.getI64IntegerAttr(warpSize));
2297 result.addTypes(resultTypes);
2298 result.addOperands(args);
2299 assert(args.size() == blockArgTypes.size());
2300 OpBuilder::InsertionGuard guard(builder);
2301 Region *warpRegion = result.addRegion();
2302 Block *block = builder.createBlock(warpRegion);
2303 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2304 block->addArgument(type, arg.getLoc());
2309 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2310 int64_t warpSize, Operation *op) {
2311 // If the types matches there is no distribution.
2312 if (expanded == distributed)
2314 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2315 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2316 if (!expandedVecType || !distributedVecType)
2317 return op->emitOpError("expected vector type for distributed operands.");
2318 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2319 expandedVecType.getElementType() != distributedVecType.getElementType())
2320 return op->emitOpError(
2321 "expected distributed vectors to have same rank and element type.");
2323 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2324 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2325 int64_t eDim = expandedVecType.getDimSize(i);
2326 int64_t dDim = distributedVecType.getDimSize(i);
2329 if (eDim % dDim != 0)
2330 return op->emitOpError()
2331 << "expected expanded vector dimension #" << i << " (" << eDim
2332 << ") to be a multipler of the distributed vector dimension ("
2334 scales[i] = eDim / dDim;
2336 if (std::accumulate(scales.begin(), scales.end(), 1,
2337 std::multiplies<int64_t>()) != warpSize)
2338 return op->emitOpError()
2339 << "incompatible distribution dimensions from " << expandedVecType
2340 << " to " << distributedVecType << " with warp size = " << warpSize;
2345 LogicalResult WarpExecuteOnLane0Op::verify() {
2346 if (getArgs().size() != getWarpRegion().getNumArguments())
2348 "expected same number op arguments and block arguments.");
2350 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2351 if (yield.getNumOperands() != getNumResults())
2353 "expected same number of yield operands and return values.");
2354 int64_t warpSize = getWarpSize();
2355 for (auto [regionArg, arg] :
2356 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2357 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2358 warpSize, getOperation())))
2361 for (auto [yieldOperand, result] :
2362 llvm::zip_equal(yield.getOperands(), getResults())) {
2363 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2364 warpSize, getOperation())))
2369 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2371 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2374 //===----------------------------------------------------------------------===//
2375 // GPU KernelMetadataAttr
2376 //===----------------------------------------------------------------------===//
2378 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2379 DictionaryAttr metadata) {
2380 assert(kernel && "invalid kernel");
2381 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2382 kernel.getAllArgAttrs(), metadata);
2386 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2387 FunctionOpInterface kernel,
2388 DictionaryAttr metadata) {
2389 assert(kernel && "invalid kernel");
2390 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2391 kernel.getAllArgAttrs(), metadata);
2395 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2398 NamedAttrList attrList;
2399 if (DictionaryAttr dict = getMetadata())
2400 attrList.append(dict);
2401 attrList.append(attrs);
2402 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2403 attrList.getDictionary(getContext()));
2407 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2408 StringAttr name, Type functionType,
2409 ArrayAttr argAttrs, DictionaryAttr metadata) {
2411 return emitError() << "the kernel name can't be empty
";
2413 if (llvm::any_of(argAttrs, [](Attribute attr) {
2414 return !llvm::isa<DictionaryAttr>(attr);
2417 << "all attributes in the array must be a dictionary attribute
";
2422 //===----------------------------------------------------------------------===//
2423 // GPU KernelTableAttr
2424 //===----------------------------------------------------------------------===//
2426 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2427 ArrayRef<KernelMetadataAttr> kernels,
2429 // Note that `is_sorted` is always only invoked once even with assertions ON.
2430 assert((!isSorted || llvm::is_sorted(kernels)) &&
2431 "expected a sorted kernel array
");
2432 // Immediately return the attribute if the array is sorted.
2433 if (isSorted || llvm::is_sorted(kernels))
2434 return Base::get(context, kernels);
2436 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2437 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2438 return Base::get(context, kernelsTmp);
2441 KernelTableAttr KernelTableAttr::getChecked(
2442 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2443 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2444 // Note that `is_sorted` is always only invoked once even with assertions ON.
2445 assert((!isSorted || llvm::is_sorted(kernels)) &&
2446 "expected a sorted kernel array
");
2447 // Immediately return the attribute if the array is sorted.
2448 if (isSorted || llvm::is_sorted(kernels))
2449 return Base::getChecked(emitError, context, kernels);
2451 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2452 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2453 return Base::getChecked(emitError, context, kernelsTmp);
2457 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2458 ArrayRef<KernelMetadataAttr> kernels) {
2459 if (kernels.size() < 2)
2461 // Check that the kernels are uniquely named.
2462 if (std::adjacent_find(kernels.begin(), kernels.end(),
2463 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2464 return l.getName() == r.getName();
2465 }) != kernels.end()) {
2466 return emitError() << "expected all kernels to be uniquely named
";
2471 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2472 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2473 return found ? *iterator : KernelMetadataAttr();
2476 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2477 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2478 return found ? *iterator : KernelMetadataAttr();
2481 //===----------------------------------------------------------------------===//
2482 // GPU target options
2483 //===----------------------------------------------------------------------===//
2485 TargetOptions::TargetOptions(
2486 StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2487 StringRef cmdOptions, CompilationTarget compilationTarget,
2488 function_ref<SymbolTable *()> getSymbolTableCallback,
2489 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2490 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2491 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2492 function_ref<void(StringRef)> isaCallback)
2493 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2494 cmdOptions, compilationTarget, getSymbolTableCallback,
2495 initialLlvmIRCallback, linkedLlvmIRCallback,
2496 optimizedLlvmIRCallback, isaCallback) {}
2498 TargetOptions::TargetOptions(
2499 TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2500 StringRef cmdOptions, CompilationTarget compilationTarget,
2501 function_ref<SymbolTable *()> getSymbolTableCallback,
2502 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2503 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2504 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2505 function_ref<void(StringRef)> isaCallback)
2506 : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2507 cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2508 getSymbolTableCallback(getSymbolTableCallback),
2509 initialLlvmIRCallback(initialLlvmIRCallback),
2510 linkedLlvmIRCallback(linkedLlvmIRCallback),
2511 optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2512 isaCallback(isaCallback), typeID(typeID) {}
2514 TypeID TargetOptions::getTypeID() const { return typeID; }
2516 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2518 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2520 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2522 SymbolTable *TargetOptions::getSymbolTable() const {
2523 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2526 function_ref<void(llvm::Module &)>
2527 TargetOptions::getInitialLlvmIRCallback() const {
2528 return initialLlvmIRCallback;
2531 function_ref<void(llvm::Module &)>
2532 TargetOptions::getLinkedLlvmIRCallback() const {
2533 return linkedLlvmIRCallback;
2536 function_ref<void(llvm::Module &)>
2537 TargetOptions::getOptimizedLlvmIRCallback() const {
2538 return optimizedLlvmIRCallback;
2541 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2545 CompilationTarget TargetOptions::getCompilationTarget() const {
2546 return compilationTarget;
2549 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2550 return CompilationTarget::Fatbin;
2553 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2554 TargetOptions::tokenizeCmdOptions() const {
2555 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2556 llvm::StringSaver stringSaver(options.first);
2557 StringRef opts = cmdOptions;
2558 // For a correct tokenization of the command line options `opts` must be
2559 // unquoted, otherwise the tokenization function returns a single string: the
2560 // unquoted `cmdOptions` -which is not the desired behavior.
2561 // Remove any quotes if they are at the beginning and end of the string:
2562 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2563 opts.consume_front("\
""), opts.consume_back(
"\"");
2564 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2565 opts.consume_front(
"'"), opts.consume_back(
"'");
2567 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2570 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2578 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2579 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2581 #define GET_ATTRDEF_CLASSES
2582 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2584 #define GET_OP_CLASSES
2585 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2587 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....