#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Location loc, Value physicalLinearMappingId) const {
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}
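// A minimal standalone sketch (illustrative only, not part of the dialect) of
// the arithmetic materialized above: the logical linear id of an active
// physical id is the number of mask bits set strictly below that physical id.
// E.g. with mask = 0b1011 the active physical ids 0, 1, 3 map to logical ids
// 0, 1, 2; logicalLinearId(0b1011, 3) == 2.
static uint64_t logicalLinearId(uint64_t mask, uint64_t physicalId) {
  uint64_t filter = (uint64_t{1} << physicalId) - 1; // bits below physicalId
  return static_cast<uint64_t>(llvm::popcount(mask & filter));
}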
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Location loc, Value physicalLinearMappingId) const {
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}
// MMAMatrixType (only partially captured in this listing; gaps marked "...").
  // ...
                          elementType, operand);
  // ...
  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
         // ...

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  // ...
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  // ...
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isConstantMemoryAddressSpace(Attribute memorySpace) {
  // ...
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getConstantAddressSpace();
  return false;
}

bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isConstantMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  if (auto gpuFunc = dyn_cast<GPUFuncOp>(op))
    return gpuFunc.isKernel();
  return static_cast<bool>(
      op->getAttrOfType<UnitAttr>(getKernelFuncAttrName()));
}

/// Inliner interface: all GPU dialect ops may be inlined.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           // ...
  declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
                            ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
                            BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
                            LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
                            SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
  declarePromisedInterfaces<memref::IndexedAccessOpInterface,
                            SubgroupMmaLoadMatrixOp,
                            SubgroupMmaStoreMatrixOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
}
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // ...
  if (keyword == "async.token")
    // ...

  if (keyword == "mma_matrix") {
    // ...
    SmallVector<int64_t> shape;
    // ...
                                shape, elementType, operand);
  }
  // ...
}

void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      // ...
      .Case<SparseDnTensorHandleType>([&](Type) {
        // ...
      })
      .Case<SparseSpMatHandleType>(
          // ...
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        // ...
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        // ...
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          // ...
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownClusterSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';
  return success();
}

/// Parses an optional list of async operands with an optional leading keyword.
static ParseResult
parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand>
                           &asyncDependencies) {
  // ...
      return parser.emitError(loc, "needs to be named when marked 'async'");
  // ...
}

/// Prints optional async dependencies with its leading keyword.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  // ...
  if (asyncDependencies.empty())
    return;
  // ...
  printer << llvm::interleaved_array(asyncDependencies);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  // ...
  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        // ...
      });
  // ...
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";
    // ...
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    // ...
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }
  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }
  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType())))
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
  }
  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");
  // ...
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  // ...
}

static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  // ...
  std::optional<AllReduceOperation> op =
      gpu::symbolizeAllReduceOperation(enumStr);
  // ...
  attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  // ...
}
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";
    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType)))
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    // ...
  // ...
}

  // ...
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    // ...
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
  // ...
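// A minimal standalone sketch (illustrative only; the helper name is
// hypothetical) of the cluster constraints enforced by the subgroup_reduce
// verifier above: the cluster size and the cluster stride must both be powers
// of two, and a non-default stride is only meaningful together with an
// explicit cluster size.
static bool isValidReduceClusterConfig(std::optional<uint32_t> clusterSize,
                                       uint32_t clusterStride) {
  if (clusterSize && !llvm::isPowerOf2_32(*clusterSize))
    return false;
  if (clusterStride != 1 && !clusterSize)
    return false;
  return llvm::isPowerOf2_32(clusterStride);
}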
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     // ...
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     // ...
  if (!workgroupAttributions.empty())
    result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
                        // ...

  result.addOperands(asyncDependencies);
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create the body region with the config arguments and the attributions.
  // ...
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    // ...
  for (Type argTy : workgroupAttributions)
    // ...
  for (Type argTy : privateAttributions)
    // ...

  // ...
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}
KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}
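// A minimal usage sketch (the helper below is hypothetical, not part of the
// dialect): after the async dependencies, the launch operands are laid out as
// grid sizes [0..2], block sizes [3..5], and, when present, cluster sizes
// [6..8], which is exactly what the accessors above index into.
static void collectLaunchDims(gpu::LaunchOp op, SmallVectorImpl<Value> &dims) {
  KernelDim3 grid = op.getGridSizeOperandValues();
  KernelDim3 block = op.getBlockSizeOperandValues();
  dims.append({grid.x, grid.y, grid.z, block.x, block.y, block.z});
  if (std::optional<KernelDim3> cluster = op.getClusterSizeOperandValues())
    dims.append({cluster->x, cluster->y, cluster->z});
}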
LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // ...
  if (getBody().empty()) {
    // ...
  }
  if (getBody().getNumArguments() <
      kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
    return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    // ...
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    // ...
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  // ...
  p << ' ' << getThreadsKeyword();
  // ...
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    // ...
  }

  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    // ...
  }

  if (getCooperative())
    // ...

  // Print remaining attributes, eliding the ones handled above.
  // ...
                          {LaunchOp::getOperandSegmentSizeAttr(),
                           getWorkgroupAttributionsAttrName(),
                           getCooperativeAttrName(), moduleAttrName,
                           // ...
}
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices,
                    StringRef keyword) {
  assert(indices.size() == 3 && "space for three indices expected");
  // ...
  if (args.size() != 3) {
    return parser.emitError(loc)
           << keyword << " expects 3 arguments, but got " << args.size();
  }
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    // ...
  }
  // ...
}
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  // ...
  if (!asyncTokenType)
    return parser.emitError(
        // ...
        "gpu.launch requires 'async' keyword to return a value");
  result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  // ...
    regionArgs.resize(18);

  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // Parse the optional cluster size assignment.
  // ...
            parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
            regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
  // ...

  // Parse the block and thread size assignments.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword()) ||
      // ...
                          regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
                          LaunchOp::getBlocksKeyword()) ||
      // ...
                          regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
                          LaunchOp::getThreadsKeyword()) ||
      // ...
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (// ...
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    // ...
  }

  StringRef moduleAttrName = getModuleAttrName(result.name);
  // ...
    FlatSymbolRefAttr moduleSymbol;
  // ...

  StringRef functionAttrName = getFunctionAttrName(result.name);
  // ...
    FlatSymbolRefAttr funcSymbol;
  // ...

  // Create the region arguments; the config arguments are all of index type.
  // ...
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  // ...
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  if (numWorkgroupAttrs != 0)
    result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(result.name),
                        // ...

  Region *body = result.addRegion();
  // ...

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      // ...
  return success();
}
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // ...
      if (id.getUses().empty())
        return;
      // ...
    };

    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  int64_t cur = getWorkgroupAttributions().value_or(0);
  setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
  return getBody().insertArgument(
      getNumConfigRegionAttributes() + static_cast<unsigned>(cur), type, loc);
}

BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are always added at the end of the block.
  return getBody().addArgument(type, loc);
}
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  // ... add grid and block size operands ...
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // ... add grid and block size operands ...
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
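// A small sketch (hypothetical helper, not part of the op definition) that
// mirrors how the builders above fill `operandSegmentSizes`: one segment per
// operand group, in operand order - async dependencies, grid sizes, block
// sizes, optional cluster sizes, optional dynamic shared memory size, kernel
// operands, and optional async object.
static SmallVector<int32_t>
launchFuncSegmentSizes(unsigned numAsyncDeps, bool hasCluster,
                       bool hasDynSharedMem, unsigned numKernelOperands,
                       bool hasAsyncObject) {
  SmallVector<int32_t> segments = {static_cast<int32_t>(numAsyncDeps),
                                   1, 1, 1,  // grid x, y, z
                                   1, 1, 1}; // block x, y, z
  int32_t cluster = hasCluster ? 1 : 0;
  segments.append({cluster, cluster, cluster, hasDynSharedMem ? 1 : 0,
                   static_cast<int32_t>(numKernelOperands),
                   hasAsyncObject ? 1 : 0});
  return segments;
}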
StringAttr LaunchFuncOp::getKernelModuleName() {
  // ...
}

StringAttr LaunchFuncOp::getKernelName() {
  // ...
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}
LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  if (!getAsyncDependencies().empty() && getAsyncObject())
    return emitOpError(
        "cannot have both async dependencies and an explicit async object");

  return success();
}
LogicalResult
LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  LaunchFuncOp launchOp = *this;
  // ...
  if (isa<GPUModuleOp>(table))
    // ...
  if (!launchOp->getParentOp() ||
      launchOp->getParentOp()->getParentOp() != table)
    // ...

  // Ignore launches that are nested more or less deep than functions in the
  // module we are currently checking.
  if (!launchOp->getAttrOfType<SymbolRefAttr>(
          LaunchFuncOp::getKernelAttrName(launchOp->getName())))
    return success();

  StringAttr kernelContainerName = launchOp.getKernelModuleName();
  Operation *kernelContainer =
      // ...
  if (!kernelContainer)
    return launchOp.emitOpError()
           << "kernel container '" << kernelContainerName.getValue()
           << "' is undefined";

  // If the container is a GPU binary op, there is nothing else to verify.
  if (isa<BinaryOp>(kernelContainer))
    return success();

  auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
  if (!kernelModule)
    return launchOp.emitOpError()
           << "kernel module '" << kernelContainerName.getValue()
           << "' is undefined";

  Operation *kernelFunc = symbolTable.lookupSymbolIn(
      kernelModule, launchOp.getKernelName());
  if (!kernelFunc)
    return launchOp.emitOpError("kernel function '")
           << launchOp.getKernel() << "' is undefined";

  auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
  if (!kernelConvertedFunction) {
    InFlightDiagnostic diag = launchOp.emitOpError()
                              << "referenced kernel '" << launchOp.getKernel()
                              << "' is not a function";
    diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
    return diag;
  }

  if (!GPUDialect::isKernel(kernelFunc))
    return launchOp.emitOpError("kernel function is missing the '")
           << GPUDialect::getKernelFuncAttrName() << "' attribute";

  // The operand and argument types must match for GPU functions.
  auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
  if (!kernelGPUFunction)
    return success();

  unsigned actualNumArguments = launchOp.getNumKernelOperands();
  unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
  if (expectedNumArguments != actualNumArguments)
    return launchOp.emitOpError("got ")
           << actualNumArguments << " kernel operands but expected "
           << expectedNumArguments;

  FunctionType functionType = kernelGPUFunction.getFunctionType();
  for (unsigned i = 0; i < expectedNumArguments; ++i) {
    if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
      return launchOp.emitOpError("type of function argument ")
             << i << " does not match";
    }
  }

  return success();
}
static ParseResult parseLaunchDimType(
    OpAsmParser &parser, Type &dimTy,
    std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
    Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  // ...
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  // ...
  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   // ...
  };
  // ...
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  // ...
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  // ...
}
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }
  return success();
}
/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                                 PatternRewriter &rewriter) {
  auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
  if (!nextOp)
    return failure();
  std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
  std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
  // ...
  if (!nextMemfence) {
    op.removeAddressSpacesAttr();
    // ...
  }
  // ...
  if (*thisMemfence == *nextMemfence) {
    // ...
  }
  // Merge the two address-space lists.
  llvm::SmallSetVector<Attribute, 4> mergedSpaces;
  for (Attribute attr : *thisMemfence)
    mergedSpaces.insert(attr);
  for (Attribute attr : *nextMemfence)
    mergedSpaces.insert(attr);
  op.setAddressSpacesAttr(rewriter.getArrayAttr(mergedSpaces.takeVector()));
  // ...
}

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // ...
}

void BarrierOp::build(mlir::OpBuilder &odsBuilder,
                      mlir::OperationState &odsState,
                      std::optional<AddressSpace> addressSpace) {
  // ...
        AddressSpaceAttr::get(odsBuilder.getContext(), addressSpace.value()));
  build(odsBuilder, odsState, addressSpacesAttr);
}

void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
                      Value memrefToFence) {
  std::optional<AddressSpace> addrSpaceToFence;
  if (auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.getType()))
    if (auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
            memrefType.getMemorySpace()))
      addrSpaceToFence = addrSpaceAttr.getValue();
  return build(builder, odsState, addrSpaceToFence);
}
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  int64_t cur = getWorkgroupAttributions().value_or(0);
  setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + static_cast<unsigned>(cur), type, loc);
}

BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Private attributions are always added at the end of the block.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);
  // ...
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getWorkgroupAttributionsAttrName(result.name),
                      // ...
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  // ...
  for (Type argTy : type.getInputs())
    // ...
  for (Type argTy : workgroupAttributions)
    // ...
  for (Type argTy : privateAttributions)
    // ...
}

/// Parses a GPU function memory attribution, together with per-attribution
/// dictionary attributes.
  // ...
  size_t existingArgs = args.size();
  // ...
  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               // ...
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return success();
  }

  // ...
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      // ...
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return success();
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  // ...

  // Parse the function signature.
  // ...
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          // ...

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  // ...
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // ...
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // Parse workgroup memory attributions.
  Attribute workgroupAttributionAttrs;
  // ...
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  if (numWorkgroupAttrs != 0)
    result.addAttribute(
        GPUFuncOp::getWorkgroupAttributionsAttrName(result.name),
        // ...
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  // Parse private memory attributions.
  Attribute privateAttributionAttrs;
  // ...
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  // ...
    result.addAttribute(GPUFuncOp::getKernelAttrName(result.name),
                        // ...

  // Parse attributes and the body region.
  // ...
  auto *body = result.addRegion();
  // ...
}
void GPUFuncOp::print(OpAsmPrinter &p) {
  // ...
  FunctionType type = getFunctionType();
  // ...
  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  // Print remaining attributes, eliding the ones handled above.
  // ...
      {getWorkgroupAttributionsAttrName(), getKernelAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  // ...
}
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  // ...
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  // ...
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  // ...
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  // ...
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  // ...
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  // ...
}

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  // ...
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}

static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  // ...
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  // ...
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      // ...
      std::swap(elems[i], elems[elems.size() - 1]);
      // ...
    }
  }
  // ...
    elems.emplace_back(name, value);
  // ...
  if (mustSort)
    DictionaryAttr::sortInPlace(elems);
  // ...
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  // ...
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return setAttributionAttr(*this, index, name, value,
                            getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return setAttributionAttr(*this, index, name, value,
                            getPrivateAttribAttrsAttrName());
}
LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";
  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  // ...
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        // ...
}

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}

void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets);
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}

LogicalResult GPUModuleOp::verify() {
  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
  if (!targets)
    return success();

  for (auto target : targets) {
    if (auto verifyTargetAttr =
            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
        return failure();
    }
  }
  return success();
}
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  // ...
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        // ...
}

static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  // ...
  if (!offloadingHandler)
    // ...
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}
LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  // ...
    return emitOpError("arguments have incompatible element type");

  // ...
    return emitOpError("arguments have incompatible shape");

  return success();
}

/// Erases a trivial gpu.memcpy whose destination is only ever freed.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    // ...
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();

    // Keep the op if its async token/dependency structure is non-trivial.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();

    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected source memref most minor dim must have unit stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected destination memref most minor dim must have unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  // For C += A * B with A : MxK, B : KxN, C : MxN, the shared K dimension and
  // the M/N result dimensions must line up.
  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  // ...
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  // ...
}

/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

/// Simplify trivial gpu.wait ops.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that have neither async dependencies nor a token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      // ...
    }
    // Replace uses of the token of a single-dependency async wait with the
    // dependency itself.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      // ...
    }
    // Erase waits whose async token has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      // ...
    }
    return failure();
  }
};

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}
LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  // ...
  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}

/// Folds memref.dim of a gpu.alloc result into the corresponding dynamic size
/// operand.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}

ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  // ...
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      // ...
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
  // SelectObjectAttr: the target must be null, a non-negative index, or a GPU
  // target attribute.
  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
    if (intAttr.getInt() < 0) {
      return emitError() << "the object index must be positive";
    }
  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
    return emitError()
           << "the target attribute must be a GPU Target attribute";
  }
  // ...

LogicalResult gpu::DynamicSharedMemoryOp::verify() {
  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
    return emitOpError() << "must be inside an op with symbol table";

  MemRefType memrefType = getResultMemref().getType();
  // Check address space.
  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
    return emitOpError()
           // ...
           << gpu::AddressSpaceAttr::getMnemonic() << "<"
           << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
  }
  if (memrefType.hasStaticShape()) {
    return emitOpError() << "result memref type must be memref<?xi8, "
                            "#gpu.address_space<workgroup>>";
  }
  return success();
}
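// A minimal sketch (illustrative only; the helper name is hypothetical) of a
// result type that satisfies the checks above: a dynamically shaped i8 memref
// in the workgroup address space, i.e. memref<?xi8,
// #gpu.address_space<workgroup>>.
static MemRefType getDynamicSharedMemoryType(MLIRContext *ctx) {
  auto workgroup =
      gpu::AddressSpaceAttr::get(ctx, gpu::AddressSpace::Workgroup);
  return MemRefType::get({ShapedType::kDynamic}, IntegerType::get(ctx, 8),
                         MemRefLayoutAttrInterface{}, workgroup);
}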
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  // Print the region, with block terminators only when there are results.
  // ...
                !getResults().empty());
  // ...
}

ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  // Parse the lane id operand.
  OpAsmParser::UnresolvedOperand laneId;
  // ...

  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        // ...

  // Parse optional captured arguments.
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  // ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             // ...

  // ...
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
  return success();
}

void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // ...
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}

ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
  // ...
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 // ...
  build(builder, result, resultTypes, laneId, warpSize,
        // ...
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 // ...
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      // ...
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  // ...
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    // ...
}
/// Checks if the distributed vector type is consistent with the expanded type
/// and the distribution (warp) size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multipler of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (llvm::product_of(scales) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
  if (!yield)
    return emitOpError("expected body to be terminated with 'gpu.yield'");
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
  return cast<gpu::YieldOp>(getBody()->getTerminator());
}
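// Worked example of the distribution rule enforced by verifyDistributedType
// above (a sketch, not dialect code; assumes equal ranks): each expanded
// dimension must be a multiple of the matching distributed dimension, and the
// product of the per-dimension scale factors must equal the warp size. With
// warp size 32, vector<32x4xf32> distributes to vector<1x4xf32>
// (scales 32 * 1 == 32) but not to vector<2x4xf32> (scales 16 * 1 != 32).
static bool distributionScalesMatchWarp(ArrayRef<int64_t> expandedShape,
                                        ArrayRef<int64_t> distributedShape,
                                        int64_t warpSize) {
  int64_t product = 1;
  for (auto [e, d] : llvm::zip_equal(expandedShape, distributedShape)) {
    if (d == 0 || e % d != 0)
      return false;
    product *= e / d;
  }
  return product == warpSize;
}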
void gpu::SubgroupBroadcastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  setResultRange(getResult(), argRanges.front());
}

// ...
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    // ...
  case BroadcastType::specific_lane:
    // ...
  }
  llvm_unreachable("Unknown BroadcastType");
// ...

LogicalResult gpu::SubgroupBroadcastOp::verify() {
  switch (getBroadcastType()) {
  case BroadcastType::first_active_lane:
    if (getLane())
      return emitOpError()
             << "lane can only be specified for `specific_lane` broadcast";
    return success();
  case BroadcastType::specific_lane:
    if (!getLane())
      return emitOpError()
             << "lane must be specified for `specific_lane` broadcast";
    return success();
  }
  llvm_unreachable("Unknown BroadcastType");
}

OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
  // A broadcast of a broadcast produces the same value.
  if (auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
    return prev.getResult();
  return {};
}
KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  // ...
                    kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  // ...
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  // ...
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 // ...
}

LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  // ...
    return emitError() << "the kernel name can't be empty";
  // ...
  if (llvm::any_of(argAttrs, [](Attribute attr) {
        return !llvm::isa<DictionaryAttr>(attr);
      }))
    return emitError()
           << "all attributes in the array must be a dictionary attribute";
  return success();
}

KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // Note that `llvm::is_sorted` is only invoked once even with assertions on.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is already sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}

KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}

LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}

KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  // ...
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  // ...
  return found ? *iterator : KernelMetadataAttr();
}
// ...
  return CompilationTarget::Fatbin;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // Remove the surrounding quotes if they are present.
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
  // ...
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
  // ...
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
  // ...
  return options;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
// ...

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
// ...
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    // ...
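// A minimal usage sketch (an assumption for illustration, not taken from this
// listing): the tokenizer above strips surrounding single or double quotes and
// then splits the option string into argv-style entries whose storage lives in
// the returned BumpPtrAllocator, so both halves of the pair must be kept alive
// together.
static void dumpTokenizedOptions(const gpu::TargetOptions &targetOptions) {
  auto [allocator, argv] = targetOptions.tokenizeCmdOptions();
  for (const char *arg : argv)
    llvm::errs() << arg << "\n";
}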
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper to check whether the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove a gpu.barrier that immediately follows another gpu.barrier; the threads are already synchronized.
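A minimal sketch of how such a cleanup can be written, assuming the check is simply "the next operation is also a gpu.barrier"; the helper name eraseTrailingGpuBarrier is hypothetical and the body paraphrases the documented intent rather than quoting the upstream pattern:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch: drop a gpu.barrier when the operation immediately after it is also a
// gpu.barrier; the surviving barrier already synchronizes the threads.
static LogicalResult eraseTrailingGpuBarrier(gpu::BarrierOp op,
                                             PatternRewriter &rewriter) {
  if (llvm::isa_and_nonnull<gpu::BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}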
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is valid a MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
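To make the MMAMatrixType helpers listed above concrete, a hedged sketch of constructing and inspecting the type; the 16x16 f16 "AOp" configuration and the helper name buildMmaMatrixType are chosen only for illustration:

#include <cassert>

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"

// Hypothetical sketch: build !gpu.mma_matrix<16x16xf16, "AOp"> and query it.
// The operand string names the role of the matrix in C += A * B.
static void buildMmaMatrixType(mlir::MLIRContext *ctx) {
  mlir::Builder builder(ctx);
  auto matType =
      mlir::gpu::MMAMatrixType::get({16, 16}, builder.getF16Type(), "AOp");
  assert(matType.getNumDims() == 2 && matType.getOperand() == "AOp");
  (void)matType;
}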
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on inval...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....