34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/ErrorHandling.h"
38 #include "llvm/Support/FormatVariadic.h"
39 #include "llvm/Support/InterleavedRange.h"
40 #include "llvm/Support/StringSaver.h"
47 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
53 int64_t GPUBlockMappingAttr::getMappingId()
const {
54 return static_cast<int64_t
>(getBlock());
57 bool GPUBlockMappingAttr::isLinearMapping()
const {
58 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
61 int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
62 return isLinearMapping()
63 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
67 int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
68 return static_cast<int64_t
>(getWarpgroup());
71 bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
72 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
75 int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
76 return isLinearMapping()
77 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
81 int64_t GPUWarpMappingAttr::getMappingId()
const {
82 return static_cast<int64_t
>(getWarp());
85 bool GPUWarpMappingAttr::isLinearMapping()
const {
86 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
89 int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
90 return isLinearMapping()
91 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
95 int64_t GPUThreadMappingAttr::getMappingId()
const {
96 return static_cast<int64_t
>(getThread());
99 bool GPUThreadMappingAttr::isLinearMapping()
const {
100 return getMappingId() >=
static_cast<int64_t
>(MappingId::LinearDim0);
103 int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
104 return isLinearMapping()
105 ? getMappingId() -
static_cast<int64_t
>(MappingId::LinearDim0)
109 int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
110 return static_cast<int64_t
>(getAddressSpace());
113 bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
114 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
117 int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
118 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
135 elementType, operand);
149 return elementType.
isF16() || elementType.
isF32() ||
158 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
159 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
161 if (shape.size() != 2)
162 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
166 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
175 bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
178 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
179 return gpuAttr.getValue() == getWorkgroupAddressSpace();
183 bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
184 Attribute memorySpace = type.getMemorySpace();
185 return isWorkgroupMemoryAddressSpace(memorySpace);
188 bool GPUDialect::isKernel(
Operation *op) {
189 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
190 return static_cast<bool>(isKernelAttr);
206 void GPUDialect::initialize() {
207 addTypes<AsyncTokenType>();
208 addTypes<MMAMatrixType>();
209 addTypes<SparseDnTensorHandleType>();
210 addTypes<SparseSpMatHandleType>();
211 addTypes<SparseSpGEMMOpHandleType>();
214 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
217 #define GET_ATTRDEF_LIST
218 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
220 addInterfaces<GPUInlinerInterface>();
221 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
223 declarePromisedInterfaces<
224 ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
225 ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
226 SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
232 return "sparse.dntensor_handle";
234 return "sparse.spmat_handle";
236 return "sparse.spgemmop_handle";
238 llvm_unreachable(
"unknown sparse handle kind");
250 if (keyword ==
"async.token")
253 if (keyword ==
"mma_matrix") {
282 shape, elementType, operand);
300 .Case<SparseDnTensorHandleType>([&](
Type) {
303 .Case<SparseSpMatHandleType>(
305 .Case<SparseSpGEMMOpHandleType>([&](
Type) {
311 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
314 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
316 .Default([](
Type) { llvm_unreachable(
"unexpected 'gpu' type kind"); });
321 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
324 " must be a dense i32 array");
325 if (array.size() != 3)
327 " must contain exactly 3 elements");
331 LogicalResult GPUDialect::verifyOperationAttribute(
Operation *op,
333 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
335 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
337 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
338 attr.
getName() != getContainerModuleAttrName())
341 auto module = dyn_cast<ModuleOp>(op);
344 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
345 << ModuleOp::getOperationName() <<
'\'';
347 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) ->
WalkResult {
350 if (!launchOp->getParentOp() ||
351 launchOp->getParentOp()->getParentOp() != module)
356 if (!launchOp->getAttrOfType<SymbolRefAttr>(
357 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
361 StringAttr kernelContainerName = launchOp.getKernelModuleName();
362 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
363 if (!kernelContainer)
365 <<
"kernel container '" << kernelContainerName.getValue()
369 if (isa<BinaryOp>(kernelContainer))
372 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
374 return launchOp.emitOpError()
375 <<
"kernel module '" << kernelContainerName.getValue()
379 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
382 << launchOp.getKernel() <<
"' is undefined";
383 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
384 if (!kernelConvertedFunction) {
386 <<
"referenced kernel '" << launchOp.getKernel()
387 <<
"' is not a function";
388 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
393 GPUDialect::getKernelFuncAttrName()))
394 return launchOp.emitOpError(
"kernel function is missing the '")
395 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
400 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
401 if (!kernelGPUFunction)
404 unsigned actualNumArguments = launchOp.getNumKernelOperands();
405 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
406 if (expectedNumArguments != actualNumArguments)
407 return launchOp.emitOpError(
"got ")
408 << actualNumArguments <<
" kernel operands but expected "
409 << expectedNumArguments;
411 auto functionType = kernelGPUFunction.getFunctionType();
412 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
413 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
414 return launchOp.emitOpError(
"type of function argument ")
415 << i <<
" does not match";
422 return walkResult.wasInterrupted() ? failure() : success();
435 return parser.
emitError(loc,
"needs to be named when marked 'async'");
450 if (asyncDependencies.empty())
454 printer << llvm::interleaved_array(asyncDependencies);
483 return llvm::formatv(
"{} : {}", v, v.getType());
485 p <<
' ' << keyword <<
'('
486 << llvm::interleaved(llvm::map_range(values, printBlockArg)) <<
')';
492 gpu::AddressSpace memorySpace) {
493 for (
Value v : attributions) {
494 auto type = llvm::dyn_cast<MemRefType>(v.getType());
496 return op->
emitOpError() <<
"expected memref type in attribution";
501 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
504 if (addressSpace.getValue() != memorySpace)
506 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
507 <<
" in attribution";
518 using Kind = gpu::AllReduceOperation;
519 if (llvm::is_contained(
520 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
522 if (!isa<FloatType>(resType))
526 if (llvm::is_contained({Kind::MINSI,
Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
527 Kind::AND, Kind::OR, Kind::XOR},
529 if (!isa<IntegerType>(resType))
536 LogicalResult gpu::AllReduceOp::verifyRegions() {
537 if (getBody().empty() != getOp().has_value())
538 return emitError(
"expected either an op attribute or a non-empty body");
539 if (!getBody().empty()) {
540 if (getBody().getNumArguments() != 2)
541 return emitError(
"expected two region arguments");
542 for (
auto argument : getBody().getArguments()) {
543 if (argument.getType() !=
getType())
544 return emitError(
"incorrect region argument type");
546 unsigned yieldCount = 0;
547 for (
Block &block : getBody()) {
548 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
549 if (yield.getNumOperands() != 1)
550 return emitError(
"expected one gpu.yield operand");
551 if (yield.getOperand(0).getType() !=
getType())
552 return emitError(
"incorrect gpu.yield type");
557 return emitError(
"expected gpu.yield op in region");
559 gpu::AllReduceOperation opName = *getOp();
561 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
562 <<
"` reduction operation is not compatible with type "
571 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
575 Region &body = launchOp.getBody();
576 assert(!body.
empty() &&
"Invalid region");
593 AllReduceOperationAttr &attr) {
596 std::optional<AllReduceOperation> op =
597 gpu::symbolizeAllReduceOperation(enumStr);
606 AllReduceOperationAttr attr) {
617 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
618 if (vecTy.isScalable())
619 return emitOpError() <<
"is not compatible with scalable vector types";
621 elemType = vecTy.getElementType();
624 gpu::AllReduceOperation opName = getOp();
626 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
627 <<
"` reduction operation is not compatible with type "
631 auto clusterSize = getClusterSize();
633 uint32_t size = *clusterSize;
634 if (!llvm::isPowerOf2_32(size)) {
635 return emitOpError() <<
"cluster size " << size
636 <<
" is not a power of two";
640 uint32_t stride = getClusterStride();
641 if (stride != 1 && !clusterSize) {
642 return emitOpError() <<
"cluster stride can only be specified if cluster "
645 if (!llvm::isPowerOf2_32(stride)) {
646 return emitOpError() <<
"cluster stride " << stride
647 <<
" is not a power of two";
653 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
654 if (getClusterSize() == 1)
671 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
675 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
693 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
702 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
711 result.
addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
712 getBlockSizeY, getBlockSizeZ});
719 if (dynamicSharedMemorySize)
728 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
731 for (
Type argTy : workgroupAttributions)
733 for (
Type argTy : privateAttributions)
737 segmentSizes.front() = asyncDependencies.size();
738 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
739 segmentSizes[7] = clusterSizeX ? 1 : 0;
740 segmentSizes[8] = clusterSizeY ? 1 : 0;
741 segmentSizes[9] = clusterSizeZ ? 1 : 0;
747 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
748 auto args = getBody().getArguments();
753 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
754 auto args = getBody().getArguments();
759 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
760 auto args = getBody().getArguments();
765 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
766 auto args = getBody().getArguments();
767 return KernelDim3{args[9], args[10], args[11]};
770 std::optional<KernelDim3> LaunchOp::getClusterIds() {
771 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
772 if (!hasClusterSize())
774 auto args = getBody().getArguments();
775 return KernelDim3{args[12], args[13], args[14]};
778 std::optional<KernelDim3> LaunchOp::getClusterSize() {
779 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
780 if (!hasClusterSize())
782 auto args = getBody().getArguments();
783 return KernelDim3{args[15], args[16], args[17]};
786 KernelDim3 LaunchOp::getGridSizeOperandValues() {
787 auto operands = getOperands().drop_front(getAsyncDependencies().size());
788 return KernelDim3{operands[0], operands[1], operands[2]};
791 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
792 auto operands = getOperands().drop_front(getAsyncDependencies().size());
793 return KernelDim3{operands[3], operands[4], operands[5]};
796 std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
797 auto operands = getOperands().drop_front(getAsyncDependencies().size());
798 if (!hasClusterSize())
800 return KernelDim3{operands[6], operands[7], operands[8]};
804 if (!(hasClusterSize()) &&
805 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
806 return emitOpError() <<
"cluster size must be all present";
810 LogicalResult LaunchOp::verifyRegions() {
814 if (!getBody().empty()) {
815 if (getBody().getNumArguments() <
816 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
817 return emitOpError(
"unexpected number of region arguments");
822 GPUDialect::getWorkgroupAddressSpace())) ||
824 GPUDialect::getPrivateAddressSpace())))
829 for (
Block &block : getBody()) {
832 if (block.back().getNumSuccessors() != 0)
834 if (!isa<gpu::TerminatorOp>(&block.back())) {
837 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
838 "' or a terminator with successors")
839 .attachNote(getLoc())
840 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
844 if (getNumResults() == 0 && getAsyncToken())
845 return emitOpError(
"needs to be named when async keyword is specified");
856 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
857 p << size.
x <<
" = " << operands.
x <<
", ";
858 p << size.
y <<
" = " << operands.
y <<
", ";
859 p << size.
z <<
" = " << operands.
z <<
')';
863 if (getAsyncToken()) {
865 if (!getAsyncDependencies().empty())
866 p <<
" [" << getAsyncDependencies() <<
']';
869 if (hasClusterSize()) {
870 p <<
' ' << getClustersKeyword();
872 getClusterSizeOperandValues().value(),
873 getClusterIds().value());
875 p <<
' ' << getBlocksKeyword();
878 p <<
' ' << getThreadsKeyword();
881 if (getDynamicSharedMemorySize())
882 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
883 << getDynamicSharedMemorySize();
892 LaunchOp::getOperandSegmentSizeAttr(),
893 getNumWorkgroupAttributionsAttrName()});
907 assert(indices.size() == 3 &&
"space for three indices expected");
913 std::move(args.begin(), args.end(), indices.begin());
915 for (
int i = 0; i < 3; ++i) {
937 sizes(LaunchOp::kNumConfigOperands);
941 LaunchOp::kNumConfigRegionAttributes);
952 result.
types.push_back(asyncTokenType);
954 bool hasCluster =
false;
959 regionArgs.resize(18);
968 regionArgsRef.slice(15, 3),
969 regionArgsRef.slice(12, 3)))
977 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
979 regionArgsRef.slice(6, 3),
980 regionArgsRef.slice(0, 3)) ||
981 parser.
parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
983 regionArgsRef.slice(9, 3),
984 regionArgsRef.slice(3, 3)) ||
990 bool hasDynamicSharedMemorySize =
false;
992 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
993 hasDynamicSharedMemorySize =
true;
1009 LaunchOp::kNumConfigRegionAttributes + 6, index);
1012 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1014 arg.
ssaName = std::get<0>(ssaValueAndType);
1015 arg.
type = std::get<1>(ssaValueAndType);
1016 regionArguments.push_back(arg);
1027 unsigned numWorkgroupAttrs = regionArguments.size() -
1028 LaunchOp::kNumConfigRegionAttributes -
1029 (hasCluster ? 6 : 0);
1030 result.
addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1047 segmentSizes.front() = asyncDependencies.size();
1050 segmentSizes[7] = 0;
1051 segmentSizes[8] = 0;
1052 segmentSizes[9] = 0;
1054 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1055 result.
addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1069 bool simplified =
false;
1070 auto constPropIdUses = [&](
Value id,
Value size) {
1074 if (
id.getUses().empty())
1081 rewriter.
create<arith::ConstantIndexOp>(op.getLoc(), 0);
1086 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1087 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1088 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1089 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1090 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1091 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1093 return success(simplified);
1105 auto attrName = getNumWorkgroupAttributionsAttrName();
1106 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1107 (*this)->setAttr(attrName,
1109 return getBody().insertArgument(
1110 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1118 return getBody().addArgument(type, loc);
1126 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1130 std::optional<KernelDim3> clusterSize) {
1131 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1132 "expected a symbol reference with a single nested reference");
1140 if (clusterSize.has_value())
1141 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1142 if (dynamicSharedMemorySize)
1147 prop.kernel = kernelSymbol;
1148 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1150 for (
auto &sz : prop.operandSegmentSizes)
1152 prop.operandSegmentSizes[0] = asyncDependencies.size();
1153 if (!clusterSize.has_value()) {
1154 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1155 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1156 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1158 prop.operandSegmentSizes[segmentSizesLen - 3] =
1159 dynamicSharedMemorySize ? 1 : 0;
1160 prop.operandSegmentSizes[segmentSizesLen - 2] =
1161 static_cast<int32_t
>(kernelOperands.size());
1162 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1170 std::optional<KernelDim3> clusterSize) {
1171 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1174 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1175 build(builder, result, kernelSymbol, gridSize,
getBlockSize,
1176 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1177 asyncDependencies, clusterSize);
1184 std::optional<KernelDim3> clusterSize) {
1188 if (clusterSize.has_value())
1189 result.
addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1190 if (dynamicSharedMemorySize)
1196 prop.kernel = kernel;
1197 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1199 for (
auto &sz : prop.operandSegmentSizes)
1201 prop.operandSegmentSizes[0] = 0;
1202 if (!clusterSize.has_value()) {
1203 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1204 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1205 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1207 prop.operandSegmentSizes[segmentSizesLen - 3] =
1208 dynamicSharedMemorySize ? 1 : 0;
1209 prop.operandSegmentSizes[segmentSizesLen - 2] =
1210 static_cast<int32_t
>(kernelOperands.size());
1211 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1214 StringAttr LaunchFuncOp::getKernelModuleName() {
1218 StringAttr LaunchFuncOp::getKernelName() {
1222 unsigned LaunchFuncOp::getNumKernelOperands() {
1223 return getKernelOperands().size();
1226 Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1227 return getKernelOperands()[i];
1230 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1231 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1232 return KernelDim3{operands[0], operands[1], operands[2]};
1235 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1236 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1237 return KernelDim3{operands[3], operands[4], operands[5]};
1240 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1241 assert(hasClusterSize() &&
1242 "cluster size is not set, check hasClusterSize() first");
1243 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1244 return KernelDim3{operands[6], operands[7], operands[8]};
1248 auto module = (*this)->getParentOfType<ModuleOp>();
1250 return emitOpError(
"expected to belong to a module");
1252 if (!module->getAttrOfType<UnitAttr>(
1253 GPUDialect::getContainerModuleAttrName()))
1254 return emitOpError(
"expected the closest surrounding module to have the '" +
1255 GPUDialect::getContainerModuleAttrName() +
1258 if (hasClusterSize()) {
1259 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1261 return emitOpError()
1262 <<
"expects types of the cluster dimensions must be the same";
1270 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1271 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1278 if (clusterValue.has_value()) {
1279 clusterXTy = clusterYTy = clusterZTy = dimTy;
1286 Type clusterYTy,
Type clusterZTy) {
1288 printer <<
": " << dimTy;
1298 auto parseElement = [&]() -> ParseResult {
1299 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1304 parseElement,
" in argument list");
1309 if (operands.empty())
1312 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1313 [&](
const auto &pair) {
1314 auto [operand, type] = pair;
1315 printer << operand <<
" : " << type;
1325 int32_t offset, int32_t width, ShuffleMode mode) {
1326 build(builder, result, value,
1341 LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1343 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
1354 results.
add(eraseRedundantGpuBarrierOps);
1364 auto attrName = getNumWorkgroupAttributionsAttrName();
1365 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1366 (*this)->setAttr(attrName,
1368 return getBody().insertArgument(
1369 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1377 return getBody().addArgument(type, loc);
1381 StringRef name, FunctionType type,
1391 result.
addAttribute(getNumWorkgroupAttributionsAttrName(),
1398 for (
Type argTy : type.getInputs())
1400 for (
Type argTy : workgroupAttributions)
1402 for (
Type argTy : privateAttributions)
1421 size_t existingArgs = args.size();
1422 ParseResult result =
1428 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1433 attributionAttrs =
nullptr;
1439 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1440 if (!argument.attrs)
1443 attributionAttrsVec.push_back(argument.attrs);
1445 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1461 StringAttr nameAttr;
1468 parser,
false, entryArgs, isVariadic, resultTypes,
1472 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1473 return parser.
emitError(signatureLocation)
1474 <<
"gpu.func requires named arguments";
1481 for (
auto &arg : entryArgs)
1482 argTypes.push_back(arg.
type);
1488 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
1489 getResAttrsAttrName(result.
name));
1494 entryArgs, workgroupAttributionAttrs)))
1499 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1500 result.
addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1502 if (workgroupAttributionAttrs)
1503 result.
addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.
name),
1504 workgroupAttributionAttrs);
1509 entryArgs, privateAttributionAttrs)))
1511 if (privateAttributionAttrs)
1513 privateAttributionAttrs);
1517 result.
addAttribute(GPUDialect::getKernelFuncAttrName(),
1532 ArrayAttr attributes) {
1536 p <<
' ' << keyword <<
'(';
1537 llvm::interleaveComma(
1540 p << v <<
" : " << v.
getType();
1542 size_t attributionIndex = pair.index();
1543 DictionaryAttr attrs;
1544 if (attributes && attributionIndex < attributes.size())
1545 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1556 FunctionType type = getFunctionType();
1562 getWorkgroupAttribAttrs().value_or(
nullptr));
1564 getPrivateAttribAttrs().value_or(
nullptr));
1566 p <<
' ' << getKernelKeyword();
1570 {getNumWorkgroupAttributionsAttrName(),
1571 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1572 getArgAttrsAttrName(), getResAttrsAttrName(),
1573 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1579 StringAttr attrName) {
1580 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1581 if (!allAttrs || index >= allAttrs.size())
1582 return DictionaryAttr();
1583 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1586 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1590 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1595 DictionaryAttr value, StringAttr attrName) {
1597 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1600 elements.append(allAttrs.begin(), allAttrs.end());
1601 while (elements.size() <= index)
1606 elements[index] = value;
1608 op->setAttr(attrName, newValue);
1611 void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1612 DictionaryAttr value) {
1616 void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1617 DictionaryAttr value) {
1622 StringAttr name, StringAttr attrsName) {
1626 return dict.get(name);
1629 Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1631 assert(index < getNumWorkgroupAttributions() &&
1632 "index must map to a workgroup attribution");
1634 getWorkgroupAttribAttrsAttrName());
1637 Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1639 assert(index < getNumPrivateAttributions() &&
1640 "index must map to a private attribution");
1642 getPrivateAttribAttrsAttrName());
1646 Attribute value, StringAttr attrsName) {
1651 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1654 bool mustSort =
true;
1655 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1656 if (elems[i].getName() == name) {
1659 std::swap(elems[i], elems[elems.size() - 1]);
1671 elems.emplace_back(name, value);
1674 DictionaryAttr::sortInPlace(elems);
1676 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1680 void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1682 assert(index < getNumWorkgroupAttributions() &&
1683 "index must map to a workgroup attribution");
1685 getWorkgroupAttribAttrsAttrName());
1688 void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1690 assert(index < getNumPrivateAttributions() &&
1691 "index must map to a private attribution");
1693 getPrivateAttribAttrsAttrName());
1696 LogicalResult GPUFuncOp::verifyType() {
1697 if (isKernel() && getFunctionType().getNumResults() != 0)
1698 return emitOpError() <<
"expected void return type for kernel function";
1704 LogicalResult GPUFuncOp::verifyBody() {
1706 return emitOpError() <<
"expected body with at least one block";
1707 unsigned numFuncArguments = getNumArguments();
1708 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1709 unsigned numBlockArguments = front().getNumArguments();
1710 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1711 return emitOpError() <<
"expected at least "
1712 << numFuncArguments + numWorkgroupAttributions
1713 <<
" arguments to body region";
1716 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1717 Type blockArgType = front().getArgument(i).getType();
1718 if (funcArgTypes[i] != blockArgType)
1719 return emitOpError() <<
"expected body region argument #" << i
1720 <<
" to be of type " << funcArgTypes[i] <<
", got "
1725 GPUDialect::getWorkgroupAddressSpace())) ||
1727 GPUDialect::getPrivateAddressSpace())))
1738 GPUFuncOp
function = (*this)->getParentOfType<GPUFuncOp>();
1740 FunctionType funType =
function.getFunctionType();
1742 if (funType.getNumResults() != getOperands().size())
1743 return emitOpError()
1744 .append(
"expected ", funType.getNumResults(),
" result operands")
1745 .attachNote(
function.getLoc())
1746 .append(
"return type declared here");
1749 llvm::zip(
function.getFunctionType().getResults(), getOperands()))) {
1750 auto [type, operand] = pair.value();
1751 if (type != operand.getType())
1752 return emitOpError() <<
"unexpected type `" << operand.getType()
1753 <<
"' for operand #" << pair.index();
1763 StringRef name, ArrayAttr targets,
1768 props.targets = targets;
1770 props.offloadingHandler = offloadingHandler;
1776 build(builder, result, name,
1777 targets.empty() ? ArrayAttr() : builder.
getArrayAttr(targets),
1781 bool GPUModuleOp::hasTarget(
Attribute target) {
1782 if (ArrayAttr targets = getTargetsAttr())
1783 return llvm::count(targets.getValue(), target);
1788 ArrayAttr &targetsAttr = getProperties().targets;
1794 auto targets = getOperation()->getAttrOfType<ArrayAttr>(
"targets");
1799 for (
auto target : targets) {
1800 if (
auto verifyTargetAttr =
1801 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1802 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1813 Attribute offloadingHandler, ArrayAttr objects) {
1817 properties.objects = objects;
1818 if (offloadingHandler)
1819 properties.offloadingHandler = offloadingHandler;
1821 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1826 build(builder, result, name, offloadingHandler,
1827 objects.empty() ? ArrayAttr() : builder.
getArrayAttr(objects));
1838 if (!offloadingHandler)
1846 printer << '<' << offloadingHandler << '>
';
1849 //===----------------------------------------------------------------------===//
1851 //===----------------------------------------------------------------------===//
1853 LogicalResult MemcpyOp::verify() {
1854 auto srcType = getSrc().getType();
1855 auto dstType = getDst().getType();
1857 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1858 return emitOpError("arguments have incompatible element type");
1860 if (failed(verifyCompatibleShape(srcType, dstType)))
1861 return emitOpError("arguments have incompatible shape");
1870 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1871 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1873 LogicalResult matchAndRewrite(MemcpyOp op,
1874 PatternRewriter &rewriter) const override {
1875 Value dest = op.getDst();
1876 Operation *destDefOp = dest.getDefiningOp();
1877 // `dest` must be defined by an op having Allocate memory effect in order to
1878 // perform the folding.
1880 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1882 // We can erase `op` iff `dest` has no other use apart from its
1883 // use by `op` and dealloc ops.
1884 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1885 return user != op &&
1886 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1889 // We can perform the folding if and only if op has a single async
1890 // dependency and produces an async token as result, or if it does not have
1891 // any async dependency and does not produce any async token result.
1892 if (op.getAsyncDependencies().size() > 1 ||
1893 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1894 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1896 rewriter.replaceOp(op, op.getAsyncDependencies());
1901 } // end anonymous namespace
1903 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1904 MLIRContext *context) {
1905 results.add<EraseTrivialCopyOp>(context);
1908 //===----------------------------------------------------------------------===//
1909 // GPU_SubgroupMmaLoadMatrixOp
1910 //===----------------------------------------------------------------------===//
1912 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1913 auto srcType = getSrcMemref().getType();
1914 auto resType = getRes().getType();
1915 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1916 auto operand = resMatrixType.getOperand();
1917 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1919 if (!srcMemrefType.isLastDimUnitStride())
1921 "expected source memref most minor dim must have unit stride");
1923 if (operand != "AOp" && operand != "BOp" && operand != "COp")
1924 return emitError("only AOp, BOp and COp can be loaded");
1929 //===----------------------------------------------------------------------===//
1930 // GPU_SubgroupMmaStoreMatrixOp
1931 //===----------------------------------------------------------------------===//
1933 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1934 auto srcType = getSrc().getType();
1935 auto dstType = getDstMemref().getType();
1936 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1937 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1939 if (!dstMemrefType.isLastDimUnitStride())
1941 "expected destination memref most minor dim must have unit stride");
1943 if (srcMatrixType.getOperand() != "COp")
1945 "expected the operand matrix being stored to have 'COp
' operand type");
1950 //===----------------------------------------------------------------------===//
1951 // GPU_SubgroupMmaComputeOp
1952 //===----------------------------------------------------------------------===//
1954 LogicalResult SubgroupMmaComputeOp::verify() {
1955 enum OperandMap { A, B, C };
1956 SmallVector<MMAMatrixType, 3> opTypes;
1957 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1958 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1959 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1961 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1962 opTypes[C].getOperand() != "COp")
1963 return emitError("operands must be in the order AOp, BOp, COp");
1965 ArrayRef<int64_t> aShape, bShape, cShape;
1966 aShape = opTypes[A].getShape();
1967 bShape = opTypes[B].getShape();
1968 cShape = opTypes[C].getShape();
1970 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1971 bShape[1] != cShape[1])
1972 return emitError("operand shapes do not satisfy matmul constraints");
1977 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1978 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1979 return memref::foldMemRefCast(*this);
1982 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1983 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1984 return memref::foldMemRefCast(*this);
1987 //===----------------------------------------------------------------------===//
1989 //===----------------------------------------------------------------------===//
1996 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1998 using OpRewritePattern::OpRewritePattern;
2000 LogicalResult matchAndRewrite(WaitOp op,
2001 PatternRewriter &rewriter) const final {
2002 auto predicate = [](Value value) {
2003 auto waitOp = value.getDefiningOp<WaitOp>();
2004 return waitOp && waitOp->getNumOperands() == 0;
2006 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2008 SmallVector<Value> validOperands;
2009 for (Value operand : op->getOperands()) {
2010 if (predicate(operand))
2012 validOperands.push_back(operand);
2014 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2026 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2028 using OpRewritePattern::OpRewritePattern;
2030 LogicalResult matchAndRewrite(WaitOp op,
2031 PatternRewriter &rewriter) const final {
2032 // Erase gpu.wait ops that neither have any async dependencies nor return
2034 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2035 rewriter.eraseOp(op);
2038 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2039 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2040 op.getAsyncToken()) {
2041 rewriter.replaceOp(op, op.getAsyncDependencies());
2044 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2045 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2046 rewriter.eraseOp(op);
2053 } // end anonymous namespace
2055 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2056 MLIRContext *context) {
2057 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2060 //===----------------------------------------------------------------------===//
2062 //===----------------------------------------------------------------------===//
2064 LogicalResult AllocOp::verify() {
2065 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2067 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2068 return emitOpError("dimension operand count does not equal memref "
2069 "dynamic dimension count");
2071 unsigned numSymbols = 0;
2072 if (!memRefType.getLayout().isIdentity())
2073 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2074 if (getSymbolOperands().size() != numSymbols) {
2076 "symbol operand count does not equal memref symbol count");
2086 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2087 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2089 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2090 PatternRewriter &rewriter) const override {
2091 std::optional<int64_t> index = dimOp.getConstantIndex();
2095 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2096 if (!memrefType || index.value() >= memrefType.getRank() ||
2097 !memrefType.isDynamicDim(index.value()))
2100 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2104 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2105 memrefType.getDynamicDimIndex(index.value()));
2106 rewriter.replaceOp(dimOp, substituteOp);
2113 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2114 MLIRContext *context) {
2115 results.add<SimplifyDimOfAllocOp>(context);
2118 //===----------------------------------------------------------------------===//
2119 // GPU object attribute
2120 //===----------------------------------------------------------------------===//
2122 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2123 Attribute target, CompilationTarget format,
2124 StringAttr object, DictionaryAttr properties,
2125 KernelTableAttr kernels) {
2127 return emitError() << "the target attribute cannot be null";
2128 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2130 return emitError() << "the target attribute must implement or promise the "
2131 "`gpu::TargetAttrInterface`";
2135 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2136 StringAttr &object) {
2137 std::optional<CompilationTarget> formatResult;
2138 StringRef enumKeyword;
2139 auto loc = odsParser.getCurrentLocation();
2140 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2141 formatResult = CompilationTarget::Fatbin;
2142 if (!formatResult &&
2144 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2145 odsParser.parseEqual())
2146 return odsParser.emitError(loc, "expected an equal sign");
2148 return odsParser.emitError(loc, "expected keyword for GPU object format");
2149 FailureOr<StringAttr> objectResult =
2150 FieldParser<StringAttr>::parse(odsParser);
2151 if (failed(objectResult))
2152 return odsParser.emitError(odsParser.getCurrentLocation(),
2153 "failed to parse GPU_ObjectAttr parameter "
2154 "'
object' which is to be a `StringAttr`");
2155 format = *formatResult;
2156 object = *objectResult;
2160 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2161 StringAttr object) {
2162 if (format != CompilationTarget::Fatbin)
2163 odsParser << stringifyEnum(format) << " = ";
2164 odsParser << object;
2168 //===----------------------------------------------------------------------===//
2169 // GPU select object attribute
2170 //===----------------------------------------------------------------------===//
2173 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2175 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2177 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2178 if (intAttr.getInt() < 0) {
2179 return emitError() << "the object index must be positive";
2181 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2183 << "the target attribute must be a GPU Target attribute";
2189 //===----------------------------------------------------------------------===//
2190 // DynamicSharedMemoryOp
2191 //===----------------------------------------------------------------------===//
2193 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2194 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2195 return emitOpError() << "must be inside an op with symbol table";
2197 MemRefType memrefType = getResultMemref().getType();
2198 // Check address space
2199 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2200 return emitOpError() << "address space must be "
2201 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2202 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2204 if (memrefType.hasStaticShape()) {
2205 return emitOpError() << "result memref type must be memref<?xi8, "
2206 "#gpu.address_space<workgroup>>";
2211 //===----------------------------------------------------------------------===//
2212 // GPU WarpExecuteOnLane0Op
2213 //===----------------------------------------------------------------------===//
2215 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2216 p << "(" << getLaneid() << ")";
2218 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2219 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2220 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2222 if (!getArgs().empty())
2223 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2224 if (!getResults().empty())
2225 p << " -> (" << getResults().getTypes() << ')
';
2227 p.printRegion(getRegion(),
2228 /*printEntryBlockArgs=*/true,
2229 /*printBlockTerminators=*/!getResults().empty());
2230 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2233 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2234 OperationState &result) {
2235 // Create the region.
2236 result.regions.reserve(1);
2237 Region *warpRegion = result.addRegion();
2239 auto &builder = parser.getBuilder();
2240 OpAsmParser::UnresolvedOperand laneId;
2242 // Parse predicate operand.
2243 if (parser.parseLParen() ||
2244 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2245 parser.parseRParen())
2249 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2250 parser.parseRSquare())
2252 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2253 builder.getContext())),
2254 builder.getI64IntegerAttr(warpSize));
2256 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2259 llvm::SMLoc inputsOperandsLoc;
2260 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2261 SmallVector<Type> inputTypes;
2262 if (succeeded(parser.parseOptionalKeyword("args"))) {
2263 if (parser.parseLParen())
2266 inputsOperandsLoc = parser.getCurrentLocation();
2267 if (parser.parseOperandList(inputsOperands) ||
2268 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2271 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2275 // Parse optional results type list.
2276 if (parser.parseOptionalArrowTypeList(result.types))
2278 // Parse the region.
2279 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2282 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2284 // Parse the optional attribute list.
2285 if (parser.parseOptionalAttrDict(result.attributes))
2290 void WarpExecuteOnLane0Op::getSuccessorRegions(
2291 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2292 if (!point.isParent()) {
2293 regions.push_back(RegionSuccessor(getResults()));
2297 // The warp region is always executed
2298 regions.push_back(RegionSuccessor(&getWarpRegion()));
2301 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2302 TypeRange resultTypes, Value laneId,
2304 build(builder, result, resultTypes, laneId, warpSize,
2305 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2308 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2309 TypeRange resultTypes, Value laneId,
2310 int64_t warpSize, ValueRange args,
2311 TypeRange blockArgTypes) {
2312 result.addOperands(laneId);
2313 result.addAttribute(getAttributeNames()[0],
2314 builder.getI64IntegerAttr(warpSize));
2315 result.addTypes(resultTypes);
2316 result.addOperands(args);
2317 assert(args.size() == blockArgTypes.size());
2318 OpBuilder::InsertionGuard guard(builder);
2319 Region *warpRegion = result.addRegion();
2320 Block *block = builder.createBlock(warpRegion);
2321 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2322 block->addArgument(type, arg.getLoc());
2327 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2328 int64_t warpSize, Operation *op) {
2329 // If the types matches there is no distribution.
2330 if (expanded == distributed)
2332 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2333 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2334 if (!expandedVecType || !distributedVecType)
2335 return op->emitOpError("expected vector type for distributed operands.");
2336 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2337 expandedVecType.getElementType() != distributedVecType.getElementType())
2338 return op->emitOpError(
2339 "expected distributed vectors to have same rank and element type.");
2341 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2342 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2343 int64_t eDim = expandedVecType.getDimSize(i);
2344 int64_t dDim = distributedVecType.getDimSize(i);
2347 if (eDim % dDim != 0)
2348 return op->emitOpError()
2349 << "expected expanded vector dimension #" << i << " (" << eDim
2350 << ") to be a multipler of the distributed vector dimension ("
2352 scales[i] = eDim / dDim;
2354 if (std::accumulate(scales.begin(), scales.end(), 1,
2355 std::multiplies<int64_t>()) != warpSize)
2356 return op->emitOpError()
2357 << "incompatible distribution dimensions from " << expandedVecType
2358 << " to " << distributedVecType << " with warp size = " << warpSize;
2363 LogicalResult WarpExecuteOnLane0Op::verify() {
2364 if (getArgs().size() != getWarpRegion().getNumArguments())
2366 "expected same number op arguments and block arguments.");
2368 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2369 if (yield.getNumOperands() != getNumResults())
2371 "expected same number of yield operands and return values.");
2372 int64_t warpSize = getWarpSize();
2373 for (auto [regionArg, arg] :
2374 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2375 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2376 warpSize, getOperation())))
2379 for (auto [yieldOperand, result] :
2380 llvm::zip_equal(yield.getOperands(), getResults())) {
2381 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2382 warpSize, getOperation())))
2387 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2389 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2392 //===----------------------------------------------------------------------===//
2393 // GPU KernelMetadataAttr
2394 //===----------------------------------------------------------------------===//
2396 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2397 DictionaryAttr metadata) {
2398 assert(kernel && "invalid kernel");
2399 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2400 kernel.getAllArgAttrs(), metadata);
2404 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2405 FunctionOpInterface kernel,
2406 DictionaryAttr metadata) {
2407 assert(kernel && "invalid kernel");
2408 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2409 kernel.getAllArgAttrs(), metadata);
2413 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2416 NamedAttrList attrList;
2417 if (DictionaryAttr dict = getMetadata())
2418 attrList.append(dict);
2419 attrList.append(attrs);
2420 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2421 attrList.getDictionary(getContext()));
2425 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2426 StringAttr name, Type functionType,
2427 ArrayAttr argAttrs, DictionaryAttr metadata) {
2429 return emitError() << "the kernel name can't be empty
";
2431 if (llvm::any_of(argAttrs, [](Attribute attr) {
2432 return !llvm::isa<DictionaryAttr>(attr);
2435 << "all attributes in the array must be a dictionary attribute
";
2440 //===----------------------------------------------------------------------===//
2441 // GPU KernelTableAttr
2442 //===----------------------------------------------------------------------===//
2444 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2445 ArrayRef<KernelMetadataAttr> kernels,
2447 // Note that `is_sorted` is always only invoked once even with assertions ON.
2448 assert((!isSorted || llvm::is_sorted(kernels)) &&
2449 "expected a sorted kernel array
");
2450 // Immediately return the attribute if the array is sorted.
2451 if (isSorted || llvm::is_sorted(kernels))
2452 return Base::get(context, kernels);
2454 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2455 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2456 return Base::get(context, kernelsTmp);
2459 KernelTableAttr KernelTableAttr::getChecked(
2460 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2461 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2462 // Note that `is_sorted` is always only invoked once even with assertions ON.
2463 assert((!isSorted || llvm::is_sorted(kernels)) &&
2464 "expected a sorted kernel array
");
2465 // Immediately return the attribute if the array is sorted.
2466 if (isSorted || llvm::is_sorted(kernels))
2467 return Base::getChecked(emitError, context, kernels);
2469 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2470 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2471 return Base::getChecked(emitError, context, kernelsTmp);
2475 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2476 ArrayRef<KernelMetadataAttr> kernels) {
2477 if (kernels.size() < 2)
2479 // Check that the kernels are uniquely named.
2480 if (std::adjacent_find(kernels.begin(), kernels.end(),
2481 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2482 return l.getName() == r.getName();
2483 }) != kernels.end()) {
2484 return emitError() << "expected all kernels to be uniquely named
";
2489 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2490 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2491 return found ? *iterator : KernelMetadataAttr();
2494 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2495 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2496 return found ? *iterator : KernelMetadataAttr();
2499 //===----------------------------------------------------------------------===//
2500 // GPU target options
2501 //===----------------------------------------------------------------------===//
2503 TargetOptions::TargetOptions(
2504 StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2505 StringRef cmdOptions, StringRef elfSection,
2506 CompilationTarget compilationTarget,
2507 function_ref<SymbolTable *()> getSymbolTableCallback,
2508 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2509 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2510 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2511 function_ref<void(StringRef)> isaCallback)
2512 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2513 cmdOptions, elfSection, compilationTarget,
2514 getSymbolTableCallback, initialLlvmIRCallback,
2515 linkedLlvmIRCallback, optimizedLlvmIRCallback,
2518 TargetOptions::TargetOptions(
2519 TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2520 StringRef cmdOptions, StringRef elfSection,
2521 CompilationTarget compilationTarget,
2522 function_ref<SymbolTable *()> getSymbolTableCallback,
2523 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2524 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2525 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2526 function_ref<void(StringRef)> isaCallback)
2527 : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2528 cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2529 compilationTarget(compilationTarget),
2530 getSymbolTableCallback(getSymbolTableCallback),
2531 initialLlvmIRCallback(initialLlvmIRCallback),
2532 linkedLlvmIRCallback(linkedLlvmIRCallback),
2533 optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2534 isaCallback(isaCallback), typeID(typeID) {}
2536 TypeID TargetOptions::getTypeID() const { return typeID; }
2538 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2540 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2541 return librariesToLink;
2544 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2546 StringRef TargetOptions::getELFSection() const { return elfSection; }
2548 SymbolTable *TargetOptions::getSymbolTable() const {
2549 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2552 function_ref<void(llvm::Module &)>
2553 TargetOptions::getInitialLlvmIRCallback() const {
2554 return initialLlvmIRCallback;
2557 function_ref<void(llvm::Module &)>
2558 TargetOptions::getLinkedLlvmIRCallback() const {
2559 return linkedLlvmIRCallback;
2562 function_ref<void(llvm::Module &)>
2563 TargetOptions::getOptimizedLlvmIRCallback() const {
2564 return optimizedLlvmIRCallback;
2567 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2571 CompilationTarget TargetOptions::getCompilationTarget() const {
2572 return compilationTarget;
2575 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2576 return CompilationTarget::Fatbin;
2579 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2580 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2581 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2582 llvm::StringSaver stringSaver(options.first);
2583 StringRef opts = cmdOptions;
2584 // For a correct tokenization of the command line options `opts` must be
2585 // unquoted, otherwise the tokenization function returns a single string: the
2586 // unquoted `cmdOptions` -which is not the desired behavior.
2587 // Remove any quotes if they are at the beginning and end of the string:
2588 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2589 opts.consume_front("\
""), opts.consume_back(
"\"");
2590 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2591 opts.consume_front(
"'"), opts.consume_back(
"'");
2593 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2596 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2602 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2607 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2609 size_t startPos =
cmdOptions.find(startsWith);
2610 if (startPos == std::string::npos)
2621 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2622 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2624 #define GET_ATTRDEF_CLASSES
2625 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2627 #define GET_OP_CLASSES
2628 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2630 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being processed.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Kind
An enumeration of the kinds of predicates.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through ....