36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
50#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
56int64_t GPUBlockMappingAttr::getMappingId()
const {
57 return static_cast<int64_t>(getBlock());
60bool GPUBlockMappingAttr::isLinearMapping()
const {
61 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
64int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
65 return isLinearMapping()
66 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
70int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
71 return static_cast<int64_t>(getWarpgroup());
74bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
75 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
78int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
79 return isLinearMapping()
80 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
84int64_t GPUWarpMappingAttr::getMappingId()
const {
85 return static_cast<int64_t>(getWarp());
88bool GPUWarpMappingAttr::isLinearMapping()
const {
89 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
92int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
93 return isLinearMapping()
94 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
98int64_t GPUThreadMappingAttr::getMappingId()
const {
99 return static_cast<int64_t>(getThread());
102bool GPUThreadMappingAttr::isLinearMapping()
const {
103 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
106int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
107 return isLinearMapping()
108 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
112int64_t GPULaneMappingAttr::getMappingId()
const {
113 return static_cast<int64_t>(getLane());
116bool GPULaneMappingAttr::isLinearMapping()
const {
117 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
120int64_t GPULaneMappingAttr::getRelativeIndex()
const {
121 return isLinearMapping()
122 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
126int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds()
const {
return 64; }
// NOTE(review): the two helpers below are truncated by extraction — the
// parameter lists, the `Value mask =` binding, and the final operand of the
// closing CmpIOp are not visible (original numbering jumps 138 -> 142 and
// 160 -> 164). Code is left byte-identical; confirm against the original
// file before relying on the details.
138Value GPUMappingMaskAttr::createLogicalLinearMappingId(
 142 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
 143 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
 144 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
 145 filter = arith::SubIOp::create(
b, loc, filter, one);
 146 Value filteredId = arith::AndIOp::create(
b, loc, mask, filter);
 147 return math::CtPopOp::create(
b, loc, filteredId);
// NOTE(review): the visible IR above builds `popcount(mask & ((1 <<
// physicalLinearMappingId) - 1))`, i.e. the count of active ids strictly
// below the physical id — presumably the logical linear id; verify.
// The helper below appears to emit the predicate `(mask & (1 <<
// physicalLinearMappingId)) != <cut-off operand>` — the last CmpIOp operand
// (presumably `zero`, built just above it) is truncated.
160Value GPUMappingMaskAttr::createIsActiveIdPredicate(
 164 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
 165 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
 166 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
 167 Value filtered = arith::AndIOp::create(
b, loc, mask, filter);
 168 Value zero = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(0));
 169 return arith::CmpIOp::create(
b, loc, arith::CmpIPredicate::ne, filtered,
173int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
174 return static_cast<int64_t>(getAddressSpace());
177bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
178 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
181int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
182 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
199 elementType, operand);
213 return elementType.
isF16() || elementType.
isF32() || elementType.
isF64() ||
222 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
223 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
225 if (
shape.size() != 2)
226 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
230 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
// NOTE(review): truncated by extraction — the original numbering jumps
// 239 -> 242 and the fallback path after the gpu::AddressSpaceAttr check
// (presumably `return false;`) is not visible. Code left byte-identical.
239bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
 242 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
 243 return gpuAttr.getValue() == getWorkgroupAddressSpace();
247bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
248 Attribute memorySpace = type.getMemorySpace();
249 return isWorkgroupMemoryAddressSpace(memorySpace);
// NOTE(review): truncated like its workgroup counterpart (numbering jumps
// 252 -> 255; the fallback `return false;` path is not visible). Code left
// byte-identical.
252bool GPUDialect::isConstantMemoryAddressSpace(
Attribute memorySpace) {
 255 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
 256 return gpuAttr.getValue() == getConstantAddressSpace();
260bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
261 Attribute memorySpace = type.getMemorySpace();
262 return isConstantMemoryAddressSpace(memorySpace);
// NOTE(review): a GPUFuncOp answers via its own isKernel(); for other ops
// the final return is cut off mid-expression — presumably it tests the
// presence of the kernel unit attribute on `op`. Confirm against the
// original file. Code left byte-identical.
265bool GPUDialect::isKernel(Operation *op) {
 266 if (
auto gpuFunc = dyn_cast<GPUFuncOp>(op))
 267 return gpuFunc.isKernel();
 268 return static_cast<bool>(
275struct GPUInlinerInterface :
public DialectInlinerInterface {
276 using DialectInlinerInterface::DialectInlinerInterface;
279 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
285void GPUDialect::initialize() {
286 addTypes<AsyncTokenType>();
287 addTypes<MMAMatrixType>();
288 addTypes<SparseDnTensorHandleType>();
289 addTypes<SparseSpMatHandleType>();
290 addTypes<SparseSpGEMMOpHandleType>();
293#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
296#define GET_ATTRDEF_LIST
297#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
299 addInterfaces<GPUInlinerInterface>();
300 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
302 declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
303 ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
304 BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
305 LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
306 SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
307 declarePromisedInterfaces<memref::IndexedAccessOpInterface,
308 SubgroupMmaLoadMatrixOp,
309 SubgroupMmaStoreMatrixOp>();
315 return "sparse.dntensor_handle";
317 return "sparse.spmat_handle";
319 return "sparse.spgemmop_handle";
321 llvm_unreachable(
"unknown sparse handle kind");
325Type GPUDialect::parseType(DialectAsmParser &parser)
const {
333 if (keyword ==
"async.token")
336 if (keyword ==
"mma_matrix") {
344 SmallVector<int64_t> shape;
365 shape, elementType, operand);
380void GPUDialect::printType(Type type, DialectAsmPrinter &os)
const {
383 .Case<SparseDnTensorHandleType>([&](Type) {
386 .Case<SparseSpMatHandleType>(
388 .Case<SparseSpGEMMOpHandleType>([&](Type) {
394 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
397 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
399 .DefaultUnreachable(
"unexpected 'gpu' type kind");
404 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
407 " must be a dense i32 array");
408 if (array.size() != 3)
410 " must contain exactly 3 elements");
414LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
415 NamedAttribute attr) {
416 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
418 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
420 if (attr.
getName() == getKnownClusterSizeAttrHelper().getName())
422 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
423 attr.
getName() != getContainerModuleAttrName())
426 auto module = dyn_cast<ModuleOp>(op);
429 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
430 << ModuleOp::getOperationName() <<
'\'';
444 return parser.
emitError(loc,
"needs to be named when marked 'async'");
459 if (asyncDependencies.empty())
463 printer << llvm::interleaved_array(asyncDependencies);
491 p <<
' ' << keyword <<
'(';
492 llvm::interleaveComma(
493 llvm::enumerate(values), p, [&p, attributes](
auto pair) {
494 BlockArgument v = pair.value();
495 p << v <<
" : " << v.
getType();
497 size_t attributionIndex = pair.index();
498 DictionaryAttr attrs;
499 if (attributes && attributionIndex < attributes.size())
500 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
510 gpu::AddressSpace memorySpace) {
511 for (
Value v : attributions) {
512 auto type = llvm::dyn_cast<MemRefType>(v.
getType());
514 return op->
emitOpError() <<
"expected memref type in attribution";
519 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
522 if (addressSpace.getValue() != memorySpace)
524 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
525 <<
" in attribution";
536 using Kind = gpu::AllReduceOperation;
537 if (llvm::is_contained(
538 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
540 if (!isa<FloatType>(resType))
544 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
545 Kind::AND, Kind::OR, Kind::XOR},
547 if (!isa<IntegerType>(resType))
554LogicalResult gpu::AllReduceOp::verifyRegions() {
555 if (getBody().empty() != getOp().has_value())
556 return emitError(
"expected either an op attribute or a non-empty body");
557 if (!getBody().empty()) {
558 if (getBody().getNumArguments() != 2)
559 return emitError(
"expected two region arguments");
560 for (
auto argument : getBody().getArguments()) {
561 if (argument.getType() !=
getType())
562 return emitError(
"incorrect region argument type");
564 unsigned yieldCount = 0;
565 for (
Block &block : getBody()) {
566 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
567 if (yield.getNumOperands() != 1)
568 return emitError(
"expected one gpu.yield operand");
569 if (yield.getOperand(0).getType() !=
getType())
570 return emitError(
"incorrect gpu.yield type");
575 return emitError(
"expected gpu.yield op in region");
577 gpu::AllReduceOperation opName = *getOp();
579 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
580 <<
"` reduction operation is not compatible with type "
589 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
593 Region &body = launchOp.getBody();
594 assert(!body.
empty() &&
"Invalid region");
600OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor ) {
611 AllReduceOperationAttr &attr) {
614 std::optional<AllReduceOperation> op =
615 gpu::symbolizeAllReduceOperation(enumStr);
618 attr = AllReduceOperationAttr::get(parser.
getContext(), *op);
624 AllReduceOperationAttr attr) {
633LogicalResult gpu::SubgroupReduceOp::verify() {
635 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
636 if (vecTy.isScalable())
637 return emitOpError() <<
"is not compatible with scalable vector types";
639 elemType = vecTy.getElementType();
642 gpu::AllReduceOperation opName = getOp();
644 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
645 <<
"` reduction operation is not compatible with type "
649 auto clusterSize = getClusterSize();
651 uint32_t size = *clusterSize;
652 if (!llvm::isPowerOf2_32(size)) {
654 <<
" is not a power of two";
658 uint32_t stride = getClusterStride();
659 if (stride != 1 && !clusterSize) {
660 return emitOpError() <<
"cluster stride can only be specified if cluster "
663 if (!llvm::isPowerOf2_32(stride)) {
664 return emitOpError() <<
"cluster stride " << stride
665 <<
" is not a power of two";
671OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
672 if (getClusterSize() == 1)
689 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
693 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
711 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
719 if (!workgroupAttributions.empty())
721 getWorkgroupAttributionsAttrName(
result.name),
725 result.addOperands(asyncDependencies);
730 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
731 getBlockSizeY, getBlockSizeZ});
733 result.addOperands(clusterSizeX);
735 result.addOperands(clusterSizeY);
737 result.addOperands(clusterSizeZ);
738 if (dynamicSharedMemorySize)
739 result.addOperands(dynamicSharedMemorySize);
743 result.addAttribute(getModuleAttrName(
result.name), module);
745 result.addAttribute(getFunctionAttrName(
result.name), function);
753 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
756 for (
Type argTy : workgroupAttributions)
758 for (
Type argTy : privateAttributions)
762 segmentSizes.front() = asyncDependencies.size();
763 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
764 segmentSizes[7] = clusterSizeX ? 1 : 0;
765 segmentSizes[8] = clusterSizeY ? 1 : 0;
766 segmentSizes[9] = clusterSizeZ ? 1 : 0;
767 result.addAttribute(getOperandSegmentSizeAttr(),
772 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
773 auto args = getBody().getArguments();
778 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
779 auto args = getBody().getArguments();
784 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
785 auto args = getBody().getArguments();
790 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
791 auto args = getBody().getArguments();
792 return KernelDim3{args[9], args[10], args[11]};
795std::optional<KernelDim3> LaunchOp::getClusterIds() {
796 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
797 if (!hasClusterSize())
799 auto args = getBody().getArguments();
800 return KernelDim3{args[12], args[13], args[14]};
803std::optional<KernelDim3> LaunchOp::getClusterSize() {
804 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
805 if (!hasClusterSize())
807 auto args = getBody().getArguments();
808 return KernelDim3{args[15], args[16], args[17]};
811KernelDim3 LaunchOp::getGridSizeOperandValues() {
812 auto operands = getOperands().drop_front(getAsyncDependencies().size());
813 return KernelDim3{operands[0], operands[1], operands[2]};
816KernelDim3 LaunchOp::getBlockSizeOperandValues() {
817 auto operands = getOperands().drop_front(getAsyncDependencies().size());
818 return KernelDim3{operands[3], operands[4], operands[5]};
821std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
822 auto operands = getOperands().drop_front(getAsyncDependencies().size());
823 if (!hasClusterSize())
825 return KernelDim3{operands[6], operands[7], operands[8]};
828LogicalResult LaunchOp::verify() {
829 if (!(hasClusterSize()) &&
830 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
831 return emitOpError() <<
"cluster size must be all present";
835LogicalResult LaunchOp::verifyRegions() {
839 if (getBody().empty()) {
842 if (getBody().getNumArguments() <
843 kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
844 return emitOpError(
"unexpected number of region arguments");
849 GPUDialect::getWorkgroupAddressSpace())) ||
851 GPUDialect::getPrivateAddressSpace())))
856 for (
Block &block : getBody()) {
859 if (block.back().getNumSuccessors() != 0)
861 if (!isa<gpu::TerminatorOp>(&block.back())) {
864 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
865 "' or a terminator with successors")
866 .attachNote(getLoc())
867 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
871 if (getNumResults() == 0 && getAsyncToken())
872 return emitOpError(
"needs to be named when async keyword is specified");
883 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
884 p << size.
x <<
" = " << operands.
x <<
", ";
885 p << size.
y <<
" = " << operands.
y <<
", ";
886 p << size.
z <<
" = " << operands.
z <<
')';
889void LaunchOp::print(OpAsmPrinter &p) {
890 if (getAsyncToken()) {
892 if (!getAsyncDependencies().empty())
893 p <<
" [" << getAsyncDependencies() <<
']';
896 if (hasClusterSize()) {
897 p <<
' ' << getClustersKeyword();
899 getClusterSizeOperandValues().value(),
900 getClusterIds().value());
902 p <<
' ' << getBlocksKeyword();
905 p <<
' ' << getThreadsKeyword();
908 if (getDynamicSharedMemorySize())
909 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
910 << getDynamicSharedMemorySize();
913 StringRef moduleAttrName = getModuleAttrName();
914 if (
auto module = getModule()) {
915 p <<
' ' << moduleAttrName <<
'(';
920 StringRef functionAttrName = getFunctionAttrName();
921 if (
auto function = getFunction()) {
922 p <<
' ' << functionAttrName <<
'(';
934 LaunchOp::getOperandSegmentSizeAttr(),
935 getWorkgroupAttributionsAttrName(),
936 moduleAttrName, functionAttrName});
951 assert(
indices.size() == 3 &&
"space for three indices expected");
958 if (args.size() != 3) {
960 << keyword <<
" expects 3 arguments, but got " << args.size();
962 std::move(args.begin(), args.end(),
indices.begin());
964 for (
int i = 0; i < 3; ++i) {
986ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &
result) {
988 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
989 sizes(LaunchOp::kNumConfigOperands);
992 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
993 LaunchOp::kNumConfigRegionAttributes);
996 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
1004 if (!asyncTokenType)
1007 "gpu.launch requires 'async' keyword to return a value");
1008 result.types.push_back(asyncTokenType);
1011 bool hasCluster =
false;
1015 regionArgs.resize(18);
1017 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
1018 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1024 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1025 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1033 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword()) ||
1035 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1036 LaunchOp::getBlocksKeyword()) ||
1039 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1040 LaunchOp::getThreadsKeyword()) ||
1045 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1046 bool hasDynamicSharedMemorySize =
false;
1048 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1049 hasDynamicSharedMemorySize =
true;
1058 StringRef moduleAttrName = getModuleAttrName(
result.name);
1060 FlatSymbolRefAttr moduleSymbol;
1068 StringRef functionAttrName = getFunctionAttrName(
result.name);
1070 FlatSymbolRefAttr funcSymbol;
1082 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1083 LaunchOp::kNumConfigRegionAttributes + 6, index);
1085 SmallVector<OpAsmParser::Argument> regionArguments;
1086 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1087 OpAsmParser::Argument arg;
1088 arg.
ssaName = std::get<0>(ssaValueAndType);
1089 arg.
type = std::get<1>(ssaValueAndType);
1090 regionArguments.push_back(arg);
1101 unsigned numWorkgroupAttrs = regionArguments.size() -
1102 LaunchOp::kNumConfigRegionAttributes -
1103 (hasCluster ? 6 : 0);
1104 if (numWorkgroupAttrs != 0)
1105 result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(
result.name),
1116 Region *body =
result.addRegion();
1121 SmallVector<int32_t, 11> segmentSizes(11, 1);
1122 segmentSizes.front() = asyncDependencies.size();
1125 segmentSizes[7] = 0;
1126 segmentSizes[8] = 0;
1127 segmentSizes[9] = 0;
1129 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1130 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1144 bool simplified =
false;
1145 auto constPropIdUses = [&](
Value id,
Value size) {
1149 if (
id.getUses().empty())
1161 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1162 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1163 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1164 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1165 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1166 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1172void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1173 MLIRContext *context) {
1174 rewrites.
add<FoldLaunchArguments>(context);
1179BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1180 int64_t cur = getWorkgroupAttributions().value_or(0);
1181 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1182 return getBody().insertArgument(
1183 getNumConfigRegionAttributes() +
static_cast<unsigned>(cur), type, loc);
1188BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1191 return getBody().addArgument(type, loc);
1198void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1199 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1200 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1201 ValueRange kernelOperands, Type asyncTokenType,
1203 std::optional<KernelDim3> clusterSize) {
1204 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1205 "expected a symbol reference with a single nested reference");
1206 result.addOperands(asyncDependencies);
1213 if (clusterSize.has_value())
1214 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1215 if (dynamicSharedMemorySize)
1216 result.addOperands(dynamicSharedMemorySize);
1217 result.addOperands(kernelOperands);
1219 Properties &prop =
result.getOrAddProperties<Properties>();
1220 prop.kernel = kernelSymbol;
1221 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1223 llvm::fill(prop.operandSegmentSizes, 1);
1224 prop.operandSegmentSizes[0] = asyncDependencies.size();
1225 if (!clusterSize.has_value()) {
1226 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1227 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1228 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1230 prop.operandSegmentSizes[segmentSizesLen - 3] =
1231 dynamicSharedMemorySize ? 1 : 0;
1232 prop.operandSegmentSizes[segmentSizesLen - 2] =
1233 static_cast<int32_t
>(kernelOperands.size());
1234 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1237void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1239 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1240 ValueRange kernelOperands, Type asyncTokenType,
1242 std::optional<KernelDim3> clusterSize) {
1243 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1245 SymbolRefAttr::get(kernelModule.getNameAttr(),
1246 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1247 build(builder,
result, kernelSymbol, gridSize, getBlockSize,
1248 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1249 asyncDependencies, clusterSize);
1252void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1254 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1255 ValueRange kernelOperands, Value asyncObject,
1256 std::optional<KernelDim3> clusterSize) {
1260 if (clusterSize.has_value())
1261 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1262 if (dynamicSharedMemorySize)
1263 result.addOperands(dynamicSharedMemorySize);
1264 result.addOperands(kernelOperands);
1266 result.addOperands(asyncObject);
1267 Properties &prop =
result.getOrAddProperties<Properties>();
1268 prop.kernel = kernel;
1269 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1271 llvm::fill(prop.operandSegmentSizes, 1);
1272 prop.operandSegmentSizes[0] = 0;
1273 if (!clusterSize.has_value()) {
1274 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1275 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1276 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1278 prop.operandSegmentSizes[segmentSizesLen - 3] =
1279 dynamicSharedMemorySize ? 1 : 0;
1280 prop.operandSegmentSizes[segmentSizesLen - 2] =
1281 static_cast<int32_t
>(kernelOperands.size());
1282 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1285StringAttr LaunchFuncOp::getKernelModuleName() {
1289StringAttr LaunchFuncOp::getKernelName() {
1293unsigned LaunchFuncOp::getNumKernelOperands() {
1294 return getKernelOperands().size();
1297Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1298 return getKernelOperands()[i];
1301KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1302 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1303 return KernelDim3{operands[0], operands[1], operands[2]};
1306KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1307 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1308 return KernelDim3{operands[3], operands[4], operands[5]};
1311KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1312 assert(hasClusterSize() &&
1313 "cluster size is not set, check hasClusterSize() first");
1314 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1315 return KernelDim3{operands[6], operands[7], operands[8]};
1318LogicalResult LaunchFuncOp::verify() {
1319 auto module = (*this)->getParentOfType<ModuleOp>();
1321 return emitOpError(
"expected to belong to a module");
1323 if (!module->getAttrOfType<UnitAttr>(
1324 GPUDialect::getContainerModuleAttrName()))
1325 return emitOpError(
"expected the closest surrounding module to have the '" +
1326 GPUDialect::getContainerModuleAttrName() +
1329 if (hasClusterSize()) {
1330 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1333 <<
"expects types of the cluster dimensions must be the same";
1340LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1341 LaunchFuncOp launchOp = *
this;
1344 if (isa<GPUModuleOp>(table))
1349 if (!launchOp->getParentOp() ||
1350 launchOp->getParentOp()->getParentOp() != table)
1355 if (!launchOp->getAttrOfType<SymbolRefAttr>(
1356 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1360 StringAttr kernelContainerName = launchOp.getKernelModuleName();
1361 Operation *kernelContainer =
1363 if (!kernelContainer)
1365 <<
"kernel container '" << kernelContainerName.getValue()
1366 <<
"' is undefined";
1369 if (isa<BinaryOp>(kernelContainer))
1372 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1374 return launchOp.emitOpError()
1375 <<
"kernel module '" << kernelContainerName.getValue()
1376 <<
"' is undefined";
1380 kernelModule, launchOp.getKernelName());
1383 << launchOp.getKernel() <<
"' is undefined";
1384 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1385 if (!kernelConvertedFunction) {
1386 InFlightDiagnostic
diag = launchOp.emitOpError()
1387 <<
"referenced kernel '" << launchOp.getKernel()
1388 <<
"' is not a function";
1389 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
1393 if (!GPUDialect::isKernel(kernelFunc))
1394 return launchOp.emitOpError(
"kernel function is missing the '")
1395 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
1400 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1401 if (!kernelGPUFunction)
1404 unsigned actualNumArguments = launchOp.getNumKernelOperands();
1405 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1406 if (expectedNumArguments != actualNumArguments)
1407 return launchOp.emitOpError(
"got ")
1408 << actualNumArguments <<
" kernel operands but expected "
1409 << expectedNumArguments;
1411 FunctionType functionType = kernelGPUFunction.getFunctionType();
1412 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
1413 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1414 return launchOp.emitOpError(
"type of function argument ")
1415 << i <<
" does not match";
1424 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1425 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1432 if (clusterValue.has_value()) {
1433 clusterXTy = clusterYTy = clusterZTy = dimTy;
1440 Type clusterYTy,
Type clusterZTy) {
1442 printer <<
": " << dimTy;
1452 auto parseElement = [&]() -> ParseResult {
1453 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1458 parseElement,
" in argument list");
1463 if (operands.empty())
1466 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1467 [&](
const auto &pair) {
1468 auto [operand, type] = pair;
1469 printer << operand <<
" : " << type;
1478void ShuffleOp::build(OpBuilder &builder, OperationState &
result, Value value,
1479 int32_t offset, int32_t width, ShuffleMode mode) {
1480 build(builder,
result, value,
1481 arith::ConstantOp::create(builder,
result.location,
1483 arith::ConstantOp::create(builder,
result.location,
1492LogicalResult RotateOp::verify() {
1493 uint32_t offset = getOffset();
1494 uint32_t width = getWidth();
1496 if (offset >= width) {
1497 return emitOpError() <<
"offset must be in the range [0, " << width <<
")";
1510 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1514 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1515 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1519 if (!nextMemfence) {
1520 op.removeAddressSpacesAttr();
1524 if (*thisMemfence == *nextMemfence) {
1528 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1530 mergedSpaces.insert(attr);
1532 mergedSpaces.insert(attr);
1533 op.setAddressSpacesAttr(rewriter.
getArrayAttr(mergedSpaces.takeVector()));
1541void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1542 MLIRContext *context) {
1546void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1547 mlir::OperationState &odsState,
1548 std::optional<AddressSpace> addressSpace) {
1552 AddressSpaceAttr::get(odsBuilder.
getContext(), addressSpace.value()));
1553 build(odsBuilder, odsState, addressSpacesAttr);
1560void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1561 Value memrefToFence) {
1562 std::optional<AddressSpace> addrSpaceToFence;
1563 if (
auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.
getType()))
1564 if (
auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1565 memrefType.getMemorySpace()))
1566 addrSpaceToFence = addrSpaceAttr.getValue();
1567 return build(builder, odsState, addrSpaceToFence);
1576BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1577 int64_t cur = getWorkgroupAttributions().value_or(0);
1578 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
1579 return getBody().insertArgument(
1580 getFunctionType().getNumInputs() +
static_cast<unsigned>(cur), type, loc);
1585BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1588 return getBody().addArgument(type, loc);
1591void GPUFuncOp::build(OpBuilder &builder, OperationState &
result,
1592 StringRef name, FunctionType type,
1595 ArrayRef<NamedAttribute> attrs) {
1596 OpBuilder::InsertionGuard g(builder);
1600 result.addAttribute(getFunctionTypeAttrName(
result.name),
1601 TypeAttr::get(type));
1602 result.addAttribute(getWorkgroupAttributionsAttrName(
result.name),
1604 result.addAttributes(attrs);
1605 Region *body =
result.addRegion();
1609 for (Type argTy : type.getInputs())
1611 for (Type argTy : workgroupAttributions)
1613 for (Type argTy : privateAttributions)
1632 size_t existingArgs = args.size();
1639 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1644 attributionAttrs =
nullptr;
1650 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1651 if (!argument.attrs)
1654 attributionAttrsVec.push_back(argument.attrs);
1656 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1665ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &
result) {
1666 SmallVector<OpAsmParser::Argument> entryArgs;
1667 SmallVector<DictionaryAttr> resultAttrs;
1668 SmallVector<Type> resultTypes;
1672 StringAttr nameAttr;
1679 parser,
false, entryArgs, isVariadic, resultTypes,
1683 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1684 return parser.
emitError(signatureLocation)
1685 <<
"gpu.func requires named arguments";
1691 SmallVector<Type> argTypes;
1692 for (
auto &arg : entryArgs)
1693 argTypes.push_back(arg.
type);
1695 result.addAttribute(getFunctionTypeAttrName(
result.name),
1696 TypeAttr::get(type));
1699 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
1700 getResAttrsAttrName(
result.name));
1702 Attribute workgroupAttributionAttrs;
1705 entryArgs, workgroupAttributionAttrs)))
1710 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1711 if (numWorkgroupAttrs != 0)
1713 GPUFuncOp::getWorkgroupAttributionsAttrName(
result.name),
1715 if (workgroupAttributionAttrs)
1716 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(
result.name),
1717 workgroupAttributionAttrs);
1719 Attribute privateAttributionAttrs;
1722 entryArgs, privateAttributionAttrs)))
1724 if (privateAttributionAttrs)
1725 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(
result.name),
1726 privateAttributionAttrs);
1730 result.addAttribute(GPUFuncOp::getKernelAttrName(
result.name),
1739 auto *body =
result.addRegion();
1743void GPUFuncOp::print(OpAsmPrinter &p) {
1747 FunctionType type = getFunctionType();
1753 getWorkgroupAttribAttrs().value_or(
nullptr));
1755 getPrivateAttribAttrs().value_or(
nullptr));
1757 p <<
' ' << getKernelKeyword();
1761 {getWorkgroupAttributionsAttrName(), getKernelAttrName(),
1762 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1763 getArgAttrsAttrName(), getResAttrsAttrName(),
1764 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1770 StringAttr attrName) {
1771 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1772 if (!allAttrs ||
index >= allAttrs.size())
1773 return DictionaryAttr();
1774 return llvm::cast<DictionaryAttr>(allAttrs[
index]);
1777DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1781DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1786 DictionaryAttr value, StringAttr attrName) {
1788 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1791 elements.append(allAttrs.begin(), allAttrs.end());
1792 while (elements.size() <=
index)
1793 elements.push_back(DictionaryAttr::get(ctx));
1795 elements[
index] = DictionaryAttr::get(ctx);
1797 elements[
index] = value;
1798 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1799 op->setAttr(attrName, newValue);
1802void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1803 DictionaryAttr value) {
1807void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1808 DictionaryAttr value) {
1813 StringAttr name, StringAttr attrsName) {
1817 return dict.get(name);
1820Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1822 assert(index < getNumWorkgroupAttributions() &&
1823 "index must map to a workgroup attribution");
1825 getWorkgroupAttribAttrsAttrName());
1828Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1830 assert(index < getNumPrivateAttributions() &&
1831 "index must map to a private attribution");
1833 getPrivateAttribAttrsAttrName());
1837 Attribute value, StringAttr attrsName) {
1842 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1845 bool mustSort =
true;
1846 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1847 if (elems[i].getName() == name) {
1850 std::swap(elems[i], elems[elems.size() - 1]);
1862 elems.emplace_back(name, value);
1865 DictionaryAttr::sortInPlace(elems);
1867 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1871void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1873 assert(index < getNumWorkgroupAttributions() &&
1874 "index must map to a workgroup attribution");
1876 getWorkgroupAttribAttrsAttrName());
1879void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1881 assert(index < getNumPrivateAttributions() &&
1882 "index must map to a private attribution");
1884 getPrivateAttribAttrsAttrName());
1887LogicalResult GPUFuncOp::verifyType() {
1888 if (isKernel() && getFunctionType().getNumResults() != 0)
1889 return emitOpError() <<
"expected void return type for kernel function";
1895LogicalResult GPUFuncOp::verifyBody() {
1897 return emitOpError() <<
"expected body with at least one block";
1898 unsigned numFuncArguments = getNumArguments();
1899 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1900 unsigned numBlockArguments = front().getNumArguments();
1901 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1903 << numFuncArguments + numWorkgroupAttributions
1904 <<
" arguments to body region";
1906 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1907 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1908 Type blockArgType = front().getArgument(i).getType();
1909 if (funcArgTypes[i] != blockArgType)
1910 return emitOpError() <<
"expected body region argument #" << i
1911 <<
" to be of type " << funcArgTypes[i] <<
", got "
1916 GPUDialect::getWorkgroupAddressSpace())) ||
1918 GPUDialect::getPrivateAddressSpace())))
1928LogicalResult gpu::ReturnOp::verify() {
1929 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1931 FunctionType funType = function.getFunctionType();
1933 if (funType.getNumResults() != getOperands().size())
1935 .append(
"expected ", funType.getNumResults(),
" result operands")
1936 .attachNote(function.getLoc())
1937 .append(
"return type declared here");
1939 for (
const auto &pair : llvm::enumerate(
1940 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1941 auto [type, operand] = pair.value();
1942 if (type != operand.getType())
1943 return emitOpError() <<
"unexpected type `" << operand.getType()
1944 <<
"' for operand #" << pair.index();
1953void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1955 Attribute offloadingHandler) {
1956 result.addRegion()->emplaceBlock();
1957 Properties &props =
result.getOrAddProperties<Properties>();
1959 props.targets = targets;
1961 props.offloadingHandler = offloadingHandler;
1964void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1965 StringRef name, ArrayRef<Attribute> targets,
1966 Attribute offloadingHandler) {
1967 build(builder,
result, name,
1972bool GPUModuleOp::hasTarget(Attribute
target) {
1973 if (
ArrayAttr targets = getTargetsAttr())
1974 return llvm::count(targets.getValue(),
target);
1978void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1979 ArrayAttr &targetsAttr = getProperties().targets;
1980 SmallVector<Attribute> targetsVector(targets);
1981 targetsAttr = ArrayAttr::get(
getContext(), targetsVector);
1984LogicalResult GPUModuleOp::verify() {
1985 auto targets = getOperation()->getAttrOfType<
ArrayAttr>(
"targets");
1990 for (
auto target : targets) {
1991 if (
auto verifyTargetAttr =
1992 llvm::dyn_cast<TargetAttrVerifyInterface>(
target)) {
1993 if (verifyTargetAttr.verifyTarget(getOperation()).
failed())
2003void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
2004 Attribute offloadingHandler,
ArrayAttr objects) {
2005 auto &properties =
result.getOrAddProperties<Properties>();
2008 properties.objects = objects;
2009 if (offloadingHandler)
2010 properties.offloadingHandler = offloadingHandler;
2012 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
2015void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
2016 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
2017 build(builder,
result, name, offloadingHandler,
2029 if (!offloadingHandler)
2036 if (offloadingHandler != SelectObjectAttr::get(op->
getContext(),
nullptr))
2037 printer <<
'<' << offloadingHandler <<
'>';
2044LogicalResult MemcpyOp::verify() {
2045 auto srcType = getSrc().getType();
2046 auto dstType = getDst().getType();
2049 return emitOpError(
"arguments have incompatible element type");
2052 return emitOpError(
"arguments have incompatible shape");
2061struct EraseTrivialCopyOp :
public OpRewritePattern<MemcpyOp> {
2062 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2064 LogicalResult matchAndRewrite(MemcpyOp op,
2065 PatternRewriter &rewriter)
const override {
2066 Value dest = op.getDst();
2075 if (llvm::any_of(dest.
getUsers(), [op, dest](Operation *user) {
2076 return user != op &&
2077 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2083 if (op.getAsyncDependencies().size() > 1 ||
2084 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2085 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2087 rewriter.
replaceOp(op, op.getAsyncDependencies());
2094void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2095 MLIRContext *context) {
2096 results.
add<EraseTrivialCopyOp>(context);
2103LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2104 auto srcType = getSrcMemref().getType();
2105 auto resType = getRes().getType();
2106 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2107 auto operand = resMatrixType.getOperand();
2108 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2110 if (!srcMemrefType.isLastDimUnitStride())
2112 "expected source memref most minor dim must have unit stride");
2114 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
2115 return emitError(
"only AOp, BOp and COp can be loaded");
2124LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2125 auto srcType = getSrc().getType();
2126 auto dstType = getDstMemref().getType();
2127 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2128 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2130 if (!dstMemrefType.isLastDimUnitStride())
2132 "expected destination memref most minor dim must have unit stride");
2134 if (srcMatrixType.getOperand() !=
"COp")
2136 "expected the operand matrix being stored to have 'COp' operand type");
2145LogicalResult SubgroupMmaComputeOp::verify() {
2146 enum OperandMap {
A,
B,
C };
2147 SmallVector<MMAMatrixType, 3> opTypes;
2148 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().
getType()));
2149 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().
getType()));
2150 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().
getType()));
2152 if (opTypes[A].getOperand() !=
"AOp" || opTypes[B].getOperand() !=
"BOp" ||
2153 opTypes[C].getOperand() !=
"COp")
2154 return emitError(
"operands must be in the order AOp, BOp, COp");
2156 ArrayRef<int64_t> aShape, bShape, cShape;
2157 aShape = opTypes[
A].getShape();
2158 bShape = opTypes[
B].getShape();
2159 cShape = opTypes[
C].getShape();
2161 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2162 bShape[1] != cShape[1])
2163 return emitError(
"operand shapes do not satisfy matmul constraints");
2168LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2169 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2173LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2174 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2187struct EraseRedundantGpuWaitOpPairs :
public OpRewritePattern<WaitOp> {
2191 LogicalResult matchAndRewrite(WaitOp op,
2192 PatternRewriter &rewriter)
const final {
2193 auto predicate = [](Value value) {
2194 auto waitOp = value.getDefiningOp<WaitOp>();
2195 return waitOp && waitOp->getNumOperands() == 0;
2197 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2199 SmallVector<Value> validOperands;
2200 for (Value operand : op->getOperands()) {
2201 if (predicate(operand))
2203 validOperands.push_back(operand);
2205 rewriter.
modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2217struct SimplifyGpuWaitOp :
public OpRewritePattern<WaitOp> {
2221 LogicalResult matchAndRewrite(WaitOp op,
2222 PatternRewriter &rewriter)
const final {
2225 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2230 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2231 op.getAsyncToken()) {
2232 rewriter.
replaceOp(op, op.getAsyncDependencies());
2236 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2246void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2247 MLIRContext *context) {
2248 results.
add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2255LogicalResult AllocOp::verify() {
2256 auto memRefType = llvm::cast<MemRefType>(getMemref().
getType());
2262 unsigned numSymbols = 0;
2263 if (!memRefType.getLayout().isIdentity())
2264 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2265 if (getSymbolOperands().size() != numSymbols) {
2267 "symbol operand count does not equal memref symbol count");
2277struct SimplifyDimOfAllocOp :
public OpRewritePattern<memref::DimOp> {
2278 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2280 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2281 PatternRewriter &rewriter)
const override {
2282 std::optional<int64_t> index = dimOp.getConstantIndex();
2286 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2287 if (!memrefType || index.value() >= memrefType.getRank() ||
2288 !memrefType.isDynamicDim(index.value()))
2291 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2295 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2296 memrefType.getDynamicDimIndex(index.value()));
2297 rewriter.
replaceOp(dimOp, substituteOp);
2304void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2305 MLIRContext *context) {
2306 results.
add<SimplifyDimOfAllocOp>(context);
2314 Attribute
target, CompilationTarget format,
2315 StringAttr
object, DictionaryAttr properties,
2316 KernelTableAttr kernels) {
2318 return emitError() <<
"the target attribute cannot be null";
2319 if (
target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2321 return emitError() <<
"the target attribute must implement or promise the "
2322 "`gpu::TargetAttrInterface`";
2326ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2327 StringAttr &
object) {
2328 std::optional<CompilationTarget> formatResult;
2329 StringRef enumKeyword;
2332 formatResult = CompilationTarget::Fatbin;
2333 if (!formatResult &&
2335 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2337 return odsParser.
emitError(loc,
"expected an equal sign");
2339 return odsParser.
emitError(loc,
"expected keyword for GPU object format");
2340 FailureOr<StringAttr> objectResult =
2341 FieldParser<StringAttr>::parse(odsParser);
2342 if (
failed(objectResult))
2344 "failed to parse GPU_ObjectAttr parameter "
2345 "'object' which is to be a `StringAttr`");
2346 format = *formatResult;
2347 object = *objectResult;
2351void printObject(AsmPrinter &odsParser, CompilationTarget format,
2352 StringAttr
object) {
2353 if (format != CompilationTarget::Fatbin)
2354 odsParser << stringifyEnum(format) <<
" = ";
2355 odsParser << object;
2368 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(
target)) {
2369 if (intAttr.getInt() < 0) {
2370 return emitError() <<
"the object index must be positive";
2372 }
else if (!
target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2374 <<
"the target attribute must be a GPU Target attribute";
2384LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2385 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2386 return emitOpError() <<
"must be inside an op with symbol table";
2388 MemRefType memrefType = getResultMemref().getType();
2390 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2392 << gpu::AddressSpaceAttr::getMnemonic() <<
"<"
2393 << stringifyEnum(gpu::AddressSpace::Workgroup) <<
">";
2395 if (memrefType.hasStaticShape()) {
2396 return emitOpError() <<
"result memref type must be memref<?xi8, "
2397 "#gpu.address_space<workgroup>>";
2406void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2407 p <<
"(" << getLaneid() <<
")";
2409 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2410 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2411 p <<
"[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() <<
"]";
2413 if (!getArgs().empty())
2414 p <<
" args(" << getArgs() <<
" : " << getArgs().getTypes() <<
")";
2415 if (!getResults().empty())
2416 p <<
" -> (" << getResults().getTypes() <<
')';
2420 !getResults().empty());
2424ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2425 OperationState &
result) {
2427 result.regions.reserve(1);
2428 Region *warpRegion =
result.addRegion();
2431 OpAsmParser::UnresolvedOperand laneId;
2443 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2450 llvm::SMLoc inputsOperandsLoc;
2451 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2452 SmallVector<Type> inputTypes;
2462 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2473 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder,
result.location);
2481void WarpExecuteOnLane0Op::getSuccessorRegions(
2482 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2489 regions.push_back(RegionSuccessor(&getWarpRegion()));
2492ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2495void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2498 build(builder,
result, resultTypes, laneId, warpSize,
2502void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2506 result.addOperands(laneId);
2507 result.addAttribute(getAttributeNames()[0],
2509 result.addTypes(resultTypes);
2510 result.addOperands(args);
2511 assert(args.size() == blockArgTypes.size());
2512 OpBuilder::InsertionGuard guard(builder);
2513 Region *warpRegion =
result.addRegion();
2515 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2524 if (expanded == distributed)
2526 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2527 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2528 if (!expandedVecType || !distributedVecType)
2529 return op->
emitOpError(
"expected vector type for distributed operands.");
2530 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2531 expandedVecType.getElementType() != distributedVecType.getElementType())
2533 "expected distributed vectors to have same rank and element type.");
2536 for (
int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2537 int64_t eDim = expandedVecType.getDimSize(i);
2538 int64_t dDim = distributedVecType.getDimSize(i);
2541 if (eDim % dDim != 0)
2543 <<
"expected expanded vector dimension #" << i <<
" (" << eDim
2544 <<
") to be a multipler of the distributed vector dimension ("
2546 scales[i] = eDim / dDim;
2548 if (llvm::product_of(scales) != warpSize)
2550 <<
"incompatible distribution dimensions from " << expandedVecType
2551 <<
" to " << distributedVecType <<
" with warp size = " << warpSize;
2556LogicalResult WarpExecuteOnLane0Op::verify() {
2557 if (getArgs().size() != getWarpRegion().getNumArguments())
2559 "expected same number op arguments and block arguments.");
2560 auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
2562 return emitOpError(
"expected body to be terminated with 'gpu.yield'");
2563 if (yield.getNumOperands() != getNumResults())
2565 "expected same number of yield operands and return values.");
2566 int64_t warpSize = getWarpSize();
2567 for (
auto [regionArg, arg] :
2568 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2570 warpSize, getOperation())))
2573 for (
auto [yieldOperand,
result] :
2574 llvm::zip_equal(yield.getOperands(), getResults())) {
2576 warpSize, getOperation())))
2581bool WarpExecuteOnLane0Op::areTypesCompatible(Type
lhs, Type
rhs) {
2586gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2587 return cast<gpu::YieldOp>(getBody()->getTerminator());
2594void gpu::SubgroupBroadcastOp::inferResultRanges(
2595 ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
2596 setResultRange(getResult(), argRanges.front());
2600 switch (getBroadcastType()) {
2601 case BroadcastType::first_active_lane:
2605 case BroadcastType::specific_lane:
2609 llvm_unreachable(
"Unknown BroadcastType");
2612LogicalResult gpu::SubgroupBroadcastOp::verify() {
2613 switch (getBroadcastType()) {
2614 case BroadcastType::first_active_lane:
2617 <<
"lane can only be specified for `specific_lane` broadcast";
2619 case BroadcastType::specific_lane:
2622 <<
"lane must be specified for `specific_lane` broadcast";
2625 llvm_unreachable(
"Unknown BroadcastType");
2628OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor ) {
2630 if (
auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2631 return prev.getResult();
2646KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2647 DictionaryAttr metadata) {
2648 assert(kernel &&
"invalid kernel");
2649 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2650 kernel.getAllArgAttrs(), metadata);
2655 FunctionOpInterface kernel,
2656 DictionaryAttr metadata) {
2657 assert(kernel &&
"invalid kernel");
2659 kernel.getAllArgAttrs(), metadata);
2663KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs)
const {
2666 NamedAttrList attrList;
2667 if (DictionaryAttr dict = getMetadata())
2670 return KernelMetadataAttr::get(getName(), getFunctionType(),
getArgAttrs(),
2676 StringAttr name, Type functionType,
2677 ArrayAttr argAttrs, DictionaryAttr metadata) {
2679 return emitError() <<
"the kernel name can't be empty";
2681 if (llvm::any_of(argAttrs, [](Attribute attr) {
2682 return !llvm::isa<DictionaryAttr>(attr);
2685 <<
"all attributes in the array must be a dictionary attribute";
2694KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2695 ArrayRef<KernelMetadataAttr> kernels,
2698 assert((!isSorted || llvm::is_sorted(kernels)) &&
2699 "expected a sorted kernel array");
2701 if (isSorted || llvm::is_sorted(kernels))
2702 return Base::get(context, kernels);
2704 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2705 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2706 return Base::get(context, kernelsTmp);
2709KernelTableAttr KernelTableAttr::getChecked(
2711 ArrayRef<KernelMetadataAttr> kernels,
bool isSorted) {
2713 assert((!isSorted || llvm::is_sorted(kernels)) &&
2714 "expected a sorted kernel array");
2716 if (isSorted || llvm::is_sorted(kernels))
2717 return Base::getChecked(
emitError, context, kernels);
2719 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2720 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2721 return Base::getChecked(
emitError, context, kernelsTmp);
2726 ArrayRef<KernelMetadataAttr> kernels) {
2727 if (kernels.size() < 2)
2730 if (std::adjacent_find(kernels.begin(), kernels.end(),
2731 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2732 return l.getName() == r.getName();
2733 }) != kernels.end()) {
2734 return emitError() <<
"expected all kernels to be uniquely named";
2739KernelMetadataAttr KernelTableAttr::lookup(StringRef key)
const {
2741 return found ? *iterator : KernelMetadataAttr();
2744KernelMetadataAttr KernelTableAttr::lookup(StringAttr key)
const {
2746 return found ? *iterator : KernelMetadataAttr();
2826 return CompilationTarget::Fatbin;
2829std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2831 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
options;
2832 llvm::StringSaver stringSaver(
options.first);
2838 if (!opts.empty() && opts.front() ==
'"' && opts.back() ==
'"')
2839 opts.consume_front(
"\""), opts.consume_back(
"\"");
2840 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2841 opts.consume_front(
"'"), opts.consume_back(
"'");
2843 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2846 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2852std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2857std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2859 size_t startPos =
cmdOptions.find(startsWith);
2860 if (startPos == std::string::npos)
2871#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2872#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2874#define GET_ATTRDEF_CLASSES
2875#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2877#define GET_OP_CLASSES
2878#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2880#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on inval...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....