36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
50#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
/// Returns the result of getBlock() converted to int64_t with a static_cast.
56int64_t GPUBlockMappingAttr::getMappingId()
const {
57 return static_cast<int64_t>(getBlock());
/// Returns true when getMappingId() is greater than or equal to
/// MappingId::LinearDim0 (compared as int64_t), false otherwise.
60bool GPUBlockMappingAttr::isLinearMapping()
const {
61 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
64int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
65 return isLinearMapping()
66 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
/// Returns the result of getWarpgroup() converted to int64_t with a
/// static_cast.
70int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
71 return static_cast<int64_t>(getWarpgroup());
/// Returns true when getMappingId() is greater than or equal to
/// MappingId::LinearDim0 (compared as int64_t), false otherwise.
74bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
75 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
78int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
79 return isLinearMapping()
80 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
/// Returns the result of getWarp() converted to int64_t with a static_cast.
84int64_t GPUWarpMappingAttr::getMappingId()
const {
85 return static_cast<int64_t>(getWarp());
/// Returns true when getMappingId() is greater than or equal to
/// MappingId::LinearDim0 (compared as int64_t), false otherwise.
88bool GPUWarpMappingAttr::isLinearMapping()
const {
89 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
92int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
93 return isLinearMapping()
94 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
/// Returns the result of getThread() converted to int64_t with a static_cast.
98int64_t GPUThreadMappingAttr::getMappingId()
const {
99 return static_cast<int64_t>(getThread());
/// Returns true when getMappingId() is greater than or equal to
/// MappingId::LinearDim0 (compared as int64_t), false otherwise.
102bool GPUThreadMappingAttr::isLinearMapping()
const {
103 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
106int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
107 return isLinearMapping()
108 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
/// Returns the result of getLane() converted to int64_t with a static_cast.
112int64_t GPULaneMappingAttr::getMappingId()
const {
113 return static_cast<int64_t>(getLane());
/// Returns true when getMappingId() is greater than or equal to
/// MappingId::LinearDim0 (compared as int64_t), false otherwise.
116bool GPULaneMappingAttr::isLinearMapping()
const {
117 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
120int64_t GPULaneMappingAttr::getRelativeIndex()
const {
121 return isLinearMapping()
122 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
/// Returns the constant 64, independent of the attribute's contents.
// NOTE(review): presumably this mirrors the 64-bit width of the integer
// constant built from getMask() by the other GPUMappingMaskAttr methods in
// this file -- confirm before changing either side.
126int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds()
const {
return 64; }
138Value GPUMappingMaskAttr::createLogicalLinearMappingId(
142 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
143 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
144 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
145 filter = arith::SubIOp::create(
b, loc, filter, one);
146 Value filteredId = arith::AndIOp::create(
b, loc, mask, filter);
147 return math::CtPopOp::create(
b, loc, filteredId);
160Value GPUMappingMaskAttr::createIsActiveIdPredicate(
164 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
165 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
166 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
167 Value filtered = arith::AndIOp::create(
b, loc, mask, filter);
168 Value zero = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(0));
169 return arith::CmpIOp::create(
b, loc, arith::CmpIPredicate::ne, filtered,
/// Returns the result of getAddressSpace() converted to int64_t with a
/// static_cast.
173int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
174 return static_cast<int64_t>(getAddressSpace());
/// Not supported for this attribute: the body unconditionally invokes
/// llvm_unreachable with an explanatory message and never produces a value.
177bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
178 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
/// Not supported for this attribute: the body unconditionally invokes
/// llvm_unreachable with an explanatory message and never produces a value.
181int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
182 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
199 elementType, operand);
213 return elementType.
isF16() || elementType.
isF32() || elementType.
isF64() ||
222 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
223 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
225 if (
shape.size() != 2)
226 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
230 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
239bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
242 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
243 return gpuAttr.getValue() == getWorkgroupAddressSpace();
/// Convenience wrapper: reads the memory space attribute of `type` and
/// returns whatever isWorkgroupMemoryAddressSpace reports for it.
247bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
248 Attribute memorySpace = type.getMemorySpace();
249 return isWorkgroupMemoryAddressSpace(memorySpace);
252bool GPUDialect::isConstantMemoryAddressSpace(
Attribute memorySpace) {
255 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
256 return gpuAttr.getValue() == getConstantAddressSpace();
/// Convenience wrapper: reads the memory space attribute of `type` and
/// returns whatever isConstantMemoryAddressSpace reports for it.
260bool GPUDialect::hasConstantMemoryAddressSpace(MemRefType type) {
261 Attribute memorySpace = type.getMemorySpace();
262 return isConstantMemoryAddressSpace(memorySpace);
265bool GPUDialect::isKernel(Operation *op) {
266 if (
auto gpuFunc = dyn_cast<GPUFuncOp>(op))
267 return gpuFunc.isKernel();
268 return static_cast<bool>(
275struct GPUInlinerInterface :
public DialectInlinerInterface {
276 using DialectInlinerInterface::DialectInlinerInterface;
279 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
285void GPUDialect::initialize() {
286 addTypes<AsyncTokenType>();
287 addTypes<MMAMatrixType>();
288 addTypes<NamedBarrierType>();
289 addTypes<SparseDnTensorHandleType>();
290 addTypes<SparseSpMatHandleType>();
291 addTypes<SparseSpGEMMOpHandleType>();
294#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
297#define GET_ATTRDEF_LIST
298#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
300 addInterfaces<GPUInlinerInterface>();
301 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
303 declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
304 ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
305 BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
306 LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
307 SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
308 declarePromisedInterfaces<memref::IndexedAccessOpInterface,
309 SubgroupMmaLoadMatrixOp,
310 SubgroupMmaStoreMatrixOp>();
316 return "sparse.dntensor_handle";
318 return "sparse.spmat_handle";
320 return "sparse.spgemmop_handle";
322 llvm_unreachable(
"unknown sparse handle kind");
326Type GPUDialect::parseType(DialectAsmParser &parser)
const {
334 if (keyword ==
"async.token")
337 if (keyword ==
"mma_matrix") {
345 SmallVector<int64_t> shape;
366 shape, elementType, operand);
369 if (keyword ==
"named_barrier")
384void GPUDialect::printType(Type type, DialectAsmPrinter &os)
const {
387 .Case<NamedBarrierType>([&](Type) { os <<
"named_barrier"; })
388 .Case<SparseDnTensorHandleType>([&](Type) {
391 .Case<SparseSpMatHandleType>(
393 .Case<SparseSpGEMMOpHandleType>([&](Type) {
399 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
402 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
404 .DefaultUnreachable(
"unexpected 'gpu' type kind");
409 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
412 " must be a dense i32 array");
413 if (array.size() != 3)
415 " must contain exactly 3 elements");
419LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
420 NamedAttribute attr) {
421 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
423 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
425 if (attr.
getName() == getKnownClusterSizeAttrHelper().getName())
427 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
428 attr.
getName() != getContainerModuleAttrName())
431 auto module = dyn_cast<ModuleOp>(op);
434 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
435 << ModuleOp::getOperationName() <<
'\'';
449 return parser.
emitError(loc,
"needs to be named when marked 'async'");
464 if (asyncDependencies.empty())
468 printer << llvm::interleaved_array(asyncDependencies);
496 p <<
' ' << keyword <<
'(';
497 llvm::interleaveComma(
498 llvm::enumerate(values), p, [&p, attributes](
auto pair) {
499 BlockArgument v = pair.value();
500 p << v <<
" : " << v.
getType();
502 size_t attributionIndex = pair.index();
503 DictionaryAttr attrs;
504 if (attributes && attributionIndex < attributes.size())
505 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
515 gpu::AddressSpace memorySpace) {
516 for (
Value v : attributions) {
517 auto type = llvm::dyn_cast<MemRefType>(v.
getType());
519 return op->
emitOpError() <<
"expected memref type in attribution";
524 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
527 if (addressSpace.getValue() != memorySpace)
529 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
530 <<
" in attribution";
541 using Kind = gpu::AllReduceOperation;
542 if (llvm::is_contained(
543 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
545 if (!isa<FloatType>(resType))
549 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
550 Kind::AND, Kind::OR, Kind::XOR},
552 if (!isa<IntegerType>(resType))
559LogicalResult gpu::AllReduceOp::verifyRegions() {
560 if (getBody().empty() != getOp().has_value())
561 return emitError(
"expected either an op attribute or a non-empty body");
562 if (!getBody().empty()) {
563 if (getBody().getNumArguments() != 2)
564 return emitError(
"expected two region arguments");
565 for (
auto argument : getBody().getArguments()) {
566 if (argument.getType() !=
getType())
567 return emitError(
"incorrect region argument type");
569 unsigned yieldCount = 0;
570 for (
Block &block : getBody()) {
571 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
572 if (yield.getNumOperands() != 1)
573 return emitError(
"expected one gpu.yield operand");
574 if (yield.getOperand(0).getType() !=
getType())
575 return emitError(
"incorrect gpu.yield type");
580 return emitError(
"expected gpu.yield op in region");
582 gpu::AllReduceOperation opName = *getOp();
584 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
585 <<
"` reduction operation is not compatible with type "
594 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
598 Region &body = launchOp.getBody();
599 assert(!body.
empty() &&
"Invalid region");
605OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor ) {
616 AllReduceOperationAttr &attr) {
619 std::optional<AllReduceOperation> op =
620 gpu::symbolizeAllReduceOperation(enumStr);
623 attr = AllReduceOperationAttr::get(parser.
getContext(), *op);
629 AllReduceOperationAttr attr) {
638LogicalResult gpu::SubgroupReduceOp::verify() {
640 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
641 if (vecTy.isScalable())
642 return emitOpError() <<
"is not compatible with scalable vector types";
644 elemType = vecTy.getElementType();
647 gpu::AllReduceOperation opName = getOp();
649 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
650 <<
"` reduction operation is not compatible with type "
654 auto clusterSize = getClusterSize();
656 uint32_t size = *clusterSize;
657 if (!llvm::isPowerOf2_32(size)) {
659 <<
" is not a power of two";
663 uint32_t stride = getClusterStride();
664 if (stride != 1 && !clusterSize) {
665 return emitOpError() <<
"cluster stride can only be specified if cluster "
668 if (!llvm::isPowerOf2_32(stride)) {
669 return emitOpError() <<
"cluster stride " << stride
670 <<
" is not a power of two";
676OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
677 if (getClusterSize() == 1)
694 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
698 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
716 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
724 if (!workgroupAttributions.empty())
726 getWorkgroupAttributionsAttrName(
result.name),
730 result.addOperands(asyncDependencies);
735 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
736 getBlockSizeY, getBlockSizeZ});
738 result.addOperands(clusterSizeX);
740 result.addOperands(clusterSizeY);
742 result.addOperands(clusterSizeZ);
743 if (dynamicSharedMemorySize)
744 result.addOperands(dynamicSharedMemorySize);
748 result.addAttribute(getModuleAttrName(
result.name), module);
750 result.addAttribute(getFunctionAttrName(
result.name), function);
758 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
761 for (
Type argTy : workgroupAttributions)
763 for (
Type argTy : privateAttributions)
767 segmentSizes.front() = asyncDependencies.size();
768 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
769 segmentSizes[7] = clusterSizeX ? 1 : 0;
770 segmentSizes[8] = clusterSizeY ? 1 : 0;
771 segmentSizes[9] = clusterSizeZ ? 1 : 0;
772 result.addAttribute(getOperandSegmentSizeAttr(),
777 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
778 auto args = getBody().getArguments();
783 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
784 auto args = getBody().getArguments();
789 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
790 auto args = getBody().getArguments();
795 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
796 auto args = getBody().getArguments();
797 return KernelDim3{args[9], args[10], args[11]};
800std::optional<KernelDim3> LaunchOp::getClusterIds() {
801 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
802 if (!hasClusterSize())
804 auto args = getBody().getArguments();
805 return KernelDim3{args[12], args[13], args[14]};
808std::optional<KernelDim3> LaunchOp::getClusterSize() {
809 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
810 if (!hasClusterSize())
812 auto args = getBody().getArguments();
813 return KernelDim3{args[15], args[16], args[17]};
/// Returns a KernelDim3 built from the operands at positions 0, 1 and 2 of
/// the operand list that remains after the async dependencies are dropped.
816KernelDim3 LaunchOp::getGridSizeOperandValues() {
// Skip the leading getAsyncDependencies().size() operands.
817 auto operands = getOperands().drop_front(getAsyncDependencies().size());
818 return KernelDim3{operands[0], operands[1], operands[2]};
/// Returns a KernelDim3 built from the operands at positions 3, 4 and 5 of
/// the operand list that remains after the async dependencies are dropped.
821KernelDim3 LaunchOp::getBlockSizeOperandValues() {
// Skip the leading getAsyncDependencies().size() operands.
822 auto operands = getOperands().drop_front(getAsyncDependencies().size());
823 return KernelDim3{operands[3], operands[4], operands[5]};
826std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
827 auto operands = getOperands().drop_front(getAsyncDependencies().size());
828 if (!hasClusterSize())
830 return KernelDim3{operands[6], operands[7], operands[8]};
/// Op verifier. The check below reports an op error when hasClusterSize() is
/// false while at least one of getClusterSizeX(), getClusterSizeY() or
/// getClusterSizeZ() evaluates to true.
833LogicalResult LaunchOp::verify() {
834 if (!(hasClusterSize()) &&
835 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
836 return emitOpError() <<
"cluster size must be all present";
840LogicalResult LaunchOp::verifyRegions() {
844 if (getBody().empty()) {
847 unsigned actualNumRegionArgs = getBody().getNumArguments();
848 unsigned expectedNumRegionArgs =
849 getNumConfigRegionAttributes() + getNumWorkgroupAttributions();
850 if (actualNumRegionArgs < expectedNumRegionArgs) {
852 << expectedNumRegionArgs <<
" region arguments, but got "
853 << actualNumRegionArgs;
858 GPUDialect::getWorkgroupAddressSpace())) ||
860 GPUDialect::getPrivateAddressSpace())))
865 for (
Block &block : getBody()) {
868 if (block.back().getNumSuccessors() != 0)
870 if (!isa<gpu::TerminatorOp>(&block.back())) {
873 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
874 "' or a terminator with successors")
875 .attachNote(getLoc())
876 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
880 if (getNumResults() == 0 && getAsyncToken())
881 return emitOpError(
"needs to be named when async keyword is specified");
892 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
893 p << size.
x <<
" = " << operands.
x <<
", ";
894 p << size.
y <<
" = " << operands.
y <<
", ";
895 p << size.
z <<
" = " << operands.
z <<
')';
898void LaunchOp::print(OpAsmPrinter &p) {
899 if (getAsyncToken()) {
901 if (!getAsyncDependencies().empty())
902 p <<
" [" << getAsyncDependencies() <<
']';
905 if (hasClusterSize()) {
906 p <<
' ' << getClustersKeyword();
908 getClusterSizeOperandValues().value(),
909 getClusterIds().value());
911 p <<
' ' << getBlocksKeyword();
914 p <<
' ' << getThreadsKeyword();
917 if (getDynamicSharedMemorySize())
918 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
919 << getDynamicSharedMemorySize();
922 StringRef moduleAttrName = getModuleAttrName();
923 if (
auto module = getModule()) {
924 p <<
' ' << moduleAttrName <<
'(';
929 StringRef functionAttrName = getFunctionAttrName();
930 if (
auto function = getFunction()) {
931 p <<
' ' << functionAttrName <<
'(';
936 if (getCooperative())
946 LaunchOp::getOperandSegmentSizeAttr(),
947 getWorkgroupAttributionsAttrName(),
948 getCooperativeAttrName(), moduleAttrName,
964 assert(
indices.size() == 3 &&
"space for three indices expected");
971 if (args.size() != 3) {
973 << keyword <<
" expects 3 arguments, but got " << args.size();
975 std::move(args.begin(), args.end(),
indices.begin());
977 for (
int i = 0; i < 3; ++i) {
999ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &
result) {
1001 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
1002 sizes(LaunchOp::kNumConfigOperands);
1005 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
1006 LaunchOp::kNumConfigRegionAttributes);
1009 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
1010 Type asyncTokenType;
1017 if (!asyncTokenType)
1020 "gpu.launch requires 'async' keyword to return a value");
1021 result.types.push_back(asyncTokenType);
1024 bool hasCluster =
false;
1028 regionArgs.resize(18);
1030 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
1031 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1037 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1038 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1046 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword()) ||
1048 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1049 LaunchOp::getBlocksKeyword()) ||
1052 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1053 LaunchOp::getThreadsKeyword()) ||
1058 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1059 bool hasDynamicSharedMemorySize =
false;
1061 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1062 hasDynamicSharedMemorySize =
true;
1071 StringRef moduleAttrName = getModuleAttrName(
result.name);
1073 FlatSymbolRefAttr moduleSymbol;
1081 StringRef functionAttrName = getFunctionAttrName(
result.name);
1083 FlatSymbolRefAttr funcSymbol;
1099 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1100 LaunchOp::kNumConfigRegionAttributes + 6, index);
1102 SmallVector<OpAsmParser::Argument> regionArguments;
1103 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1104 OpAsmParser::Argument arg;
1105 arg.
ssaName = std::get<0>(ssaValueAndType);
1106 arg.
type = std::get<1>(ssaValueAndType);
1107 regionArguments.push_back(arg);
1118 unsigned numWorkgroupAttrs = regionArguments.size() -
1119 LaunchOp::kNumConfigRegionAttributes -
1120 (hasCluster ? 6 : 0);
1121 if (numWorkgroupAttrs != 0)
1122 result.addAttribute(LaunchOp::getWorkgroupAttributionsAttrName(
result.name),
1133 Region *body =
result.addRegion();
1138 SmallVector<int32_t, 11> segmentSizes(11, 1);
1139 segmentSizes.front() = asyncDependencies.size();
1142 segmentSizes[7] = 0;
1143 segmentSizes[8] = 0;
1144 segmentSizes[9] = 0;
1146 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1147 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1161 bool simplified =
false;
1162 auto constPropIdUses = [&](
Value id,
Value size) {
1166 if (
id.getUses().empty())
1178 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1179 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1180 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1181 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1182 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1183 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
/// Registers the FoldLaunchArguments pattern, constructed with `context`,
/// into the `rewrites` pattern set.
1189void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1190 MLIRContext *context) {
1191 rewrites.
add<FoldLaunchArguments>(context);
/// Adds one workgroup attribution of the given `type` at `loc` and returns
/// the block argument created for it.
1196BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
// Current attribution count; treated as 0 when the optional is unset.
1197 int64_t cur = getWorkgroupAttributions().value_or(0);
// Record the incremented count before inserting the new argument.
1198 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
// The new argument is inserted at index getNumConfigRegionAttributes() + cur.
1199 return getBody().insertArgument(
1200 getNumConfigRegionAttributes() +
static_cast<unsigned>(cur), type, loc);
1205BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1208 return getBody().addArgument(type, loc);
1215void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1216 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1217 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1218 ValueRange kernelOperands, Type asyncTokenType,
1220 std::optional<KernelDim3> clusterSize) {
1221 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1222 "expected a symbol reference with a single nested reference");
1223 result.addOperands(asyncDependencies);
1230 if (clusterSize.has_value())
1231 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1232 if (dynamicSharedMemorySize)
1233 result.addOperands(dynamicSharedMemorySize);
1234 result.addOperands(kernelOperands);
1236 Properties &prop =
result.getOrAddProperties<Properties>();
1237 prop.kernel = kernelSymbol;
1238 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1240 llvm::fill(prop.operandSegmentSizes, 1);
1241 prop.operandSegmentSizes[0] = asyncDependencies.size();
1242 if (!clusterSize.has_value()) {
1243 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1244 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1245 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1247 prop.operandSegmentSizes[segmentSizesLen - 3] =
1248 dynamicSharedMemorySize ? 1 : 0;
1249 prop.operandSegmentSizes[segmentSizesLen - 2] =
1250 static_cast<int32_t
>(kernelOperands.size());
1251 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1254void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1256 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1257 ValueRange kernelOperands, Type asyncTokenType,
1259 std::optional<KernelDim3> clusterSize) {
1260 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1262 SymbolRefAttr::get(kernelModule.getNameAttr(),
1263 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1264 build(builder,
result, kernelSymbol, gridSize, getBlockSize,
1265 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1266 asyncDependencies, clusterSize);
1269void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1271 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1272 ValueRange kernelOperands, Value asyncObject,
1273 std::optional<KernelDim3> clusterSize) {
1277 if (clusterSize.has_value())
1278 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1279 if (dynamicSharedMemorySize)
1280 result.addOperands(dynamicSharedMemorySize);
1281 result.addOperands(kernelOperands);
1283 result.addOperands(asyncObject);
1284 Properties &prop =
result.getOrAddProperties<Properties>();
1285 prop.kernel = kernel;
1286 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1288 llvm::fill(prop.operandSegmentSizes, 1);
1289 prop.operandSegmentSizes[0] = 0;
1290 if (!clusterSize.has_value()) {
1291 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1292 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1293 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1295 prop.operandSegmentSizes[segmentSizesLen - 3] =
1296 dynamicSharedMemorySize ? 1 : 0;
1297 prop.operandSegmentSizes[segmentSizesLen - 2] =
1298 static_cast<int32_t
>(kernelOperands.size());
1299 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1302StringAttr LaunchFuncOp::getKernelModuleName() {
1306StringAttr LaunchFuncOp::getKernelName() {
/// Returns the number of elements in getKernelOperands().
1310unsigned LaunchFuncOp::getNumKernelOperands() {
1311 return getKernelOperands().size();
/// Returns the `i`-th element of getKernelOperands(). No bounds check is
/// performed here.
1314Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1315 return getKernelOperands()[i];
/// Returns a KernelDim3 built from the operands at positions 0, 1 and 2 of
/// the operand list that remains after the async dependencies are dropped.
1318KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
// Skip the leading getAsyncDependencies().size() operands.
1319 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1320 return KernelDim3{operands[0], operands[1], operands[2]};
/// Returns a KernelDim3 built from the operands at positions 3, 4 and 5 of
/// the operand list that remains after the async dependencies are dropped.
1323KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
// Skip the leading getAsyncDependencies().size() operands.
1324 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1325 return KernelDim3{operands[3], operands[4], operands[5]};
/// Returns a KernelDim3 built from the operands at positions 6, 7 and 8 of
/// the operand list that remains after the async dependencies are dropped.
/// Callers must check hasClusterSize() first; this is enforced by an assert.
1328KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1329 assert(hasClusterSize() &&
1330 "cluster size is not set, check hasClusterSize() first");
// Skip the leading getAsyncDependencies().size() operands.
1331 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1332 return KernelDim3{operands[6], operands[7], operands[8]};
1335LogicalResult LaunchFuncOp::verify() {
1336 auto module = (*this)->getParentOfType<ModuleOp>();
1338 return emitOpError(
"expected to belong to a module");
1340 if (!module->getAttrOfType<UnitAttr>(
1341 GPUDialect::getContainerModuleAttrName()))
1342 return emitOpError(
"expected the closest surrounding module to have the '" +
1343 GPUDialect::getContainerModuleAttrName() +
1346 if (hasClusterSize()) {
1347 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1350 <<
"expects types of the cluster dimensions must be the same";
1353 if (!getAsyncDependencies().empty() && getAsyncObject())
1355 "cannot have both async dependencies and an explicit async object");
1361LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1362 LaunchFuncOp launchOp = *
this;
1365 if (isa<GPUModuleOp>(table))
1370 if (!launchOp->getParentOp() ||
1371 launchOp->getParentOp()->getParentOp() != table)
1376 if (!launchOp->getAttrOfType<SymbolRefAttr>(
1377 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1381 StringAttr kernelContainerName = launchOp.getKernelModuleName();
1382 Operation *kernelContainer =
1384 if (!kernelContainer)
1386 <<
"kernel container '" << kernelContainerName.getValue()
1387 <<
"' is undefined";
1390 if (isa<BinaryOp>(kernelContainer))
1393 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1395 return launchOp.emitOpError()
1396 <<
"kernel module '" << kernelContainerName.getValue()
1397 <<
"' is undefined";
1401 kernelModule, launchOp.getKernelName());
1404 << launchOp.getKernel() <<
"' is undefined";
1405 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1406 if (!kernelConvertedFunction) {
1407 InFlightDiagnostic
diag = launchOp.emitOpError()
1408 <<
"referenced kernel '" << launchOp.getKernel()
1409 <<
"' is not a function";
1410 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
1414 if (!GPUDialect::isKernel(kernelFunc))
1415 return launchOp.emitOpError(
"kernel function is missing the '")
1416 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
1421 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1422 if (!kernelGPUFunction)
1425 unsigned actualNumArguments = launchOp.getNumKernelOperands();
1426 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1427 if (expectedNumArguments != actualNumArguments)
1428 return launchOp.emitOpError(
"got ")
1429 << actualNumArguments <<
" kernel operands but expected "
1430 << expectedNumArguments;
1432 FunctionType functionType = kernelGPUFunction.getFunctionType();
1433 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
1434 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1435 return launchOp.emitOpError(
"type of function argument ")
1436 << i <<
" does not match";
1445 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1446 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1453 if (clusterValue.has_value()) {
1454 clusterXTy = clusterYTy = clusterZTy = dimTy;
1461 Type clusterYTy,
Type clusterZTy) {
1463 printer <<
": " << dimTy;
1473 auto parseElement = [&]() -> ParseResult {
1474 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1479 parseElement,
" in argument list");
1484 if (operands.empty())
1487 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1488 [&](
const auto &pair) {
1489 auto [operand, type] = pair;
1490 printer << operand <<
" : " << type;
1499void ShuffleOp::build(OpBuilder &builder, OperationState &
result, Value value,
1500 int32_t offset, int32_t width, ShuffleMode mode) {
1501 build(builder,
result, value,
1502 arith::ConstantOp::create(builder,
result.location,
1504 arith::ConstantOp::create(builder,
result.location,
/// Op verifier. Reads getOffset() and getWidth() as uint32_t and reports an
/// op error when offset is not strictly less than width.
1513LogicalResult RotateOp::verify() {
1514 uint32_t offset = getOffset();
1515 uint32_t width = getWidth();
// offset must lie in the half-open range [0, width).
1517 if (offset >= width) {
1518 return emitOpError() <<
"offset must be in the range [0, " << width <<
")";
/// Op verifier. Reports an op error when getNamedBarrier() evaluates to true
/// while getScope() is anything other than BarrierScope::Workgroup.
1528LogicalResult BarrierOp::verify() {
1529 BarrierScope scope = getScope();
1531 if (getNamedBarrier() && scope != BarrierScope::Workgroup)
1532 return emitOpError(
"named barriers require workgroup scope");
1540 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1545 if (op.getScope() != nextOp.getScope())
1549 if (op.getNamedBarrier() != nextOp.getNamedBarrier())
1552 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1553 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1557 if (!nextMemfence) {
1558 op.removeAddressSpacesAttr();
1562 if (*thisMemfence == *nextMemfence) {
1566 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1568 mergedSpaces.insert(attr);
1570 mergedSpaces.insert(attr);
1571 op.setAddressSpacesAttr(rewriter.
getArrayAttr(mergedSpaces.takeVector()));
1579void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1580 MLIRContext *context) {
1584void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1585 mlir::OperationState &odsState,
1586 std::optional<AddressSpace> addressSpace) {
1590 AddressSpaceAttr::get(odsBuilder.
getContext(), addressSpace.value()));
1592 odsBuilder, odsState, addressSpacesAttr, Value{},
1593 BarrierScopeAttr::get(odsBuilder.
getContext(), BarrierScope::Workgroup));
/// Builder overload taking a value to fence. Derives an optional address
/// space from `memrefToFence` and forwards to the build overload that accepts
/// a std::optional<AddressSpace>.
1600void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1601 Value memrefToFence) {
// Stays empty unless both casts below succeed.
1602 std::optional<AddressSpace> addrSpaceToFence;
// Only values whose type is a BaseMemRefType are inspected further.
1603 if (
auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.
getType()))
// dyn_cast_if_present tolerates a null memory space attribute.
1604 if (
auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1605 memrefType.getMemorySpace()))
1606 addrSpaceToFence = addrSpaceAttr.getValue();
1607 return build(builder, odsState, addrSpaceToFence);
/// Adds one workgroup attribution of the given `type` at `loc` and returns
/// the block argument created for it.
1616BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
// Current attribution count; treated as 0 when the optional is unset.
1617 int64_t cur = getWorkgroupAttributions().value_or(0);
// Record the incremented count before inserting the new argument.
1618 setWorkgroupAttributions(std::optional<int64_t>(cur + 1));
// The new argument is inserted at index
// getFunctionType().getNumInputs() + cur.
1619 return getBody().insertArgument(
1620 getFunctionType().getNumInputs() +
static_cast<unsigned>(cur), type, loc);
1625BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1628 return getBody().addArgument(type, loc);
1631void GPUFuncOp::build(OpBuilder &builder, OperationState &
result,
1632 StringRef name, FunctionType type,
1635 ArrayRef<NamedAttribute> attrs) {
1636 OpBuilder::InsertionGuard g(builder);
1640 result.addAttribute(getFunctionTypeAttrName(
result.name),
1641 TypeAttr::get(type));
1642 result.addAttribute(getWorkgroupAttributionsAttrName(
result.name),
1644 result.addAttributes(attrs);
1645 Region *body =
result.addRegion();
1649 for (Type argTy : type.getInputs())
1651 for (Type argTy : workgroupAttributions)
1653 for (Type argTy : privateAttributions)
1672 size_t existingArgs = args.size();
1679 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1684 attributionAttrs =
nullptr;
1690 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1691 if (!argument.attrs)
1694 attributionAttrsVec.push_back(argument.attrs);
1696 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1705ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &
result) {
1706 SmallVector<OpAsmParser::Argument> entryArgs;
1707 SmallVector<DictionaryAttr> resultAttrs;
1708 SmallVector<Type> resultTypes;
1712 StringAttr nameAttr;
1719 parser,
false, entryArgs, isVariadic, resultTypes,
1723 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1724 return parser.
emitError(signatureLocation)
1725 <<
"gpu.func requires named arguments";
1731 SmallVector<Type> argTypes;
1732 for (
auto &arg : entryArgs)
1733 argTypes.push_back(arg.
type);
1735 result.addAttribute(getFunctionTypeAttrName(
result.name),
1736 TypeAttr::get(type));
1739 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
1740 getResAttrsAttrName(
result.name));
1742 Attribute workgroupAttributionAttrs;
1745 entryArgs, workgroupAttributionAttrs)))
1750 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1751 if (numWorkgroupAttrs != 0)
1753 GPUFuncOp::getWorkgroupAttributionsAttrName(
result.name),
1755 if (workgroupAttributionAttrs)
1756 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(
result.name),
1757 workgroupAttributionAttrs);
1759 Attribute privateAttributionAttrs;
1762 entryArgs, privateAttributionAttrs)))
1764 if (privateAttributionAttrs)
1765 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(
result.name),
1766 privateAttributionAttrs);
1770 result.addAttribute(GPUFuncOp::getKernelAttrName(
result.name),
1779 auto *body =
result.addRegion();
1783void GPUFuncOp::print(OpAsmPrinter &p) {
1787 FunctionType type = getFunctionType();
1793 getWorkgroupAttribAttrs().value_or(
nullptr));
1795 getPrivateAttribAttrs().value_or(
nullptr));
1797 p <<
' ' << getKernelKeyword();
1801 {getWorkgroupAttributionsAttrName(), getKernelAttrName(),
1802 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1803 getArgAttrsAttrName(), getResAttrsAttrName(),
1804 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1810 StringAttr attrName) {
1811 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1812 if (!allAttrs ||
index >= allAttrs.size())
1813 return DictionaryAttr();
1814 return llvm::cast<DictionaryAttr>(allAttrs[
index]);
1817DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1821DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1826 DictionaryAttr value, StringAttr attrName) {
1828 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1831 elements.append(allAttrs.begin(), allAttrs.end());
1832 while (elements.size() <=
index)
1833 elements.push_back(DictionaryAttr::get(ctx));
1835 elements[
index] = DictionaryAttr::get(ctx);
1837 elements[
index] = value;
1838 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1839 op->setAttr(attrName, newValue);
1842void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1843 DictionaryAttr value) {
1847void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1848 DictionaryAttr value) {
1853 StringAttr name, StringAttr attrsName) {
1857 return dict.get(name);
1860Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1862 assert(index < getNumWorkgroupAttributions() &&
1863 "index must map to a workgroup attribution");
1865 getWorkgroupAttribAttrsAttrName());
1868Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1870 assert(index < getNumPrivateAttributions() &&
1871 "index must map to a private attribution");
1873 getPrivateAttribAttrsAttrName());
1877 Attribute value, StringAttr attrsName) {
1882 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1885 bool mustSort =
true;
1886 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1887 if (elems[i].getName() == name) {
1890 std::swap(elems[i], elems[elems.size() - 1]);
1902 elems.emplace_back(name, value);
1905 DictionaryAttr::sortInPlace(elems);
1907 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1911void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1913 assert(index < getNumWorkgroupAttributions() &&
1914 "index must map to a workgroup attribution");
1916 getWorkgroupAttribAttrsAttrName());
1919void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1921 assert(index < getNumPrivateAttributions() &&
1922 "index must map to a private attribution");
1924 getPrivateAttribAttrsAttrName());
1927LogicalResult GPUFuncOp::verifyType() {
1928 if (isKernel() && getFunctionType().getNumResults() != 0)
1929 return emitOpError() <<
"expected void return type for kernel function";
1935LogicalResult GPUFuncOp::verifyBody() {
1937 return emitOpError() <<
"expected body with at least one block";
1938 unsigned numFuncArguments = getNumArguments();
1939 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1940 unsigned numBlockArguments = front().getNumArguments();
1941 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1943 << numFuncArguments + numWorkgroupAttributions
1944 <<
" arguments to body region";
1946 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1947 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1948 Type blockArgType = front().getArgument(i).getType();
1949 if (funcArgTypes[i] != blockArgType)
1950 return emitOpError() <<
"expected body region argument #" << i
1951 <<
" to be of type " << funcArgTypes[i] <<
", got "
1956 GPUDialect::getWorkgroupAddressSpace())) ||
1958 GPUDialect::getPrivateAddressSpace())))
1968LogicalResult gpu::ReturnOp::verify() {
1969 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1971 FunctionType funType = function.getFunctionType();
1973 if (funType.getNumResults() != getOperands().size())
1975 .append(
"expected ", funType.getNumResults(),
" result operands")
1976 .attachNote(function.getLoc())
1977 .append(
"return type declared here");
1979 for (
const auto &pair : llvm::enumerate(
1980 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1981 auto [type, operand] = pair.value();
1982 if (type != operand.getType())
1983 return emitOpError() <<
"unexpected type `" << operand.getType()
1984 <<
"' for operand #" << pair.index();
1993void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1995 Attribute offloadingHandler) {
1996 result.addRegion()->emplaceBlock();
1997 Properties &props =
result.getOrAddProperties<Properties>();
1999 props.targets = targets;
2001 props.offloadingHandler = offloadingHandler;
2004void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
2005 StringRef name, ArrayRef<Attribute> targets,
2006 Attribute offloadingHandler) {
2007 build(builder,
result, name,
2012bool GPUModuleOp::hasTarget(Attribute
target) {
2013 if (
ArrayAttr targets = getTargetsAttr())
2014 return llvm::count(targets.getValue(),
target);
2018void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
2019 ArrayAttr &targetsAttr = getProperties().targets;
2020 SmallVector<Attribute> targetsVector(targets);
2021 targetsAttr = ArrayAttr::get(
getContext(), targetsVector);
2024LogicalResult GPUModuleOp::verify() {
2025 auto targets = getOperation()->getAttrOfType<
ArrayAttr>(
"targets");
2030 for (
auto target : targets) {
2031 if (
auto verifyTargetAttr =
2032 llvm::dyn_cast<TargetAttrVerifyInterface>(
target)) {
2033 if (verifyTargetAttr.verifyTarget(getOperation()).
failed())
2043void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
2044 Attribute offloadingHandler,
ArrayAttr objects) {
2045 auto &properties =
result.getOrAddProperties<Properties>();
2048 properties.objects = objects;
2049 if (offloadingHandler)
2050 properties.offloadingHandler = offloadingHandler;
2052 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
2055void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
2056 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
2057 build(builder,
result, name, offloadingHandler,
2069 if (!offloadingHandler)
2076 if (offloadingHandler != SelectObjectAttr::get(op->
getContext(),
nullptr))
2077 printer <<
'<' << offloadingHandler <<
'>';
2084LogicalResult MemcpyOp::verify() {
2085 auto srcType = getSrc().getType();
2086 auto dstType = getDst().getType();
2089 return emitOpError(
"arguments have incompatible element type");
2092 return emitOpError(
"arguments have incompatible shape");
2101struct EraseTrivialCopyOp :
public OpRewritePattern<MemcpyOp> {
2102 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2104 LogicalResult matchAndRewrite(MemcpyOp op,
2105 PatternRewriter &rewriter)
const override {
2106 Value dest = op.getDst();
2115 if (llvm::any_of(dest.
getUsers(), [op, dest](Operation *user) {
2116 return user != op &&
2117 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2123 if (op.getAsyncDependencies().size() > 1 ||
2124 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2125 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2127 rewriter.
replaceOp(op, op.getAsyncDependencies());
2134void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2135 MLIRContext *context) {
2136 results.
add<EraseTrivialCopyOp>(context);
2143LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2144 auto srcType = getSrcMemref().getType();
2145 auto resType = getRes().getType();
2146 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2147 auto operand = resMatrixType.getOperand();
2148 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2150 if (!srcMemrefType.isLastDimUnitStride())
2152 "expected source memref most minor dim must have unit stride");
2154 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
2155 return emitError(
"only AOp, BOp and COp can be loaded");
2164LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2165 auto srcType = getSrc().getType();
2166 auto dstType = getDstMemref().getType();
2167 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2168 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2170 if (!dstMemrefType.isLastDimUnitStride())
2172 "expected destination memref most minor dim must have unit stride");
2174 if (srcMatrixType.getOperand() !=
"COp")
2176 "expected the operand matrix being stored to have 'COp' operand type");
2185LogicalResult SubgroupMmaComputeOp::verify() {
2186 enum OperandMap {
A,
B,
C };
2187 SmallVector<MMAMatrixType, 3> opTypes;
2188 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().
getType()));
2189 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().
getType()));
2190 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().
getType()));
2192 if (opTypes[A].getOperand() !=
"AOp" || opTypes[B].getOperand() !=
"BOp" ||
2193 opTypes[C].getOperand() !=
"COp")
2194 return emitError(
"operands must be in the order AOp, BOp, COp");
2196 ArrayRef<int64_t> aShape, bShape, cShape;
2197 aShape = opTypes[
A].getShape();
2198 bShape = opTypes[
B].getShape();
2199 cShape = opTypes[
C].getShape();
2201 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2202 bShape[1] != cShape[1])
2203 return emitError(
"operand shapes do not satisfy matmul constraints");
2208LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2209 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2213LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2214 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2227struct EraseRedundantGpuWaitOpPairs :
public OpRewritePattern<WaitOp> {
2231 LogicalResult matchAndRewrite(WaitOp op,
2232 PatternRewriter &rewriter)
const final {
2233 auto predicate = [](Value value) {
2234 auto waitOp = value.getDefiningOp<WaitOp>();
2235 return waitOp && waitOp->getNumOperands() == 0;
2237 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2239 SmallVector<Value> validOperands;
2240 for (Value operand : op->getOperands()) {
2241 if (predicate(operand))
2243 validOperands.push_back(operand);
2245 rewriter.
modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2257struct SimplifyGpuWaitOp :
public OpRewritePattern<WaitOp> {
2261 LogicalResult matchAndRewrite(WaitOp op,
2262 PatternRewriter &rewriter)
const final {
2265 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2270 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2271 op.getAsyncToken()) {
2272 rewriter.
replaceOp(op, op.getAsyncDependencies());
2276 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2286void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2287 MLIRContext *context) {
2288 results.
add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2295LogicalResult AllocOp::verify() {
2296 auto memRefType = llvm::cast<MemRefType>(getMemref().
getType());
2302 unsigned numSymbols = 0;
2303 if (!memRefType.getLayout().isIdentity())
2304 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2305 if (getSymbolOperands().size() != numSymbols) {
2307 "symbol operand count does not equal memref symbol count");
2317struct SimplifyDimOfAllocOp :
public OpRewritePattern<memref::DimOp> {
2318 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2320 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2321 PatternRewriter &rewriter)
const override {
2322 std::optional<int64_t> index = dimOp.getConstantIndex();
2326 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2327 if (!memrefType || index.value() >= memrefType.getRank() ||
2328 !memrefType.isDynamicDim(index.value()))
2331 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2335 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2336 memrefType.getDynamicDimIndex(index.value()));
2337 rewriter.
replaceOp(dimOp, substituteOp);
2344void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2345 MLIRContext *context) {
2346 results.
add<SimplifyDimOfAllocOp>(context);
2354 Attribute
target, CompilationTarget format,
2355 StringAttr
object, DictionaryAttr properties,
2356 KernelTableAttr kernels) {
2358 return emitError() <<
"the target attribute cannot be null";
2359 if (
target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2361 return emitError() <<
"the target attribute must implement or promise the "
2362 "`gpu::TargetAttrInterface`";
2366ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2367 StringAttr &
object) {
2368 std::optional<CompilationTarget> formatResult;
2369 StringRef enumKeyword;
2372 formatResult = CompilationTarget::Fatbin;
2373 if (!formatResult &&
2375 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2377 return odsParser.
emitError(loc,
"expected an equal sign");
2379 return odsParser.
emitError(loc,
"expected keyword for GPU object format");
2380 FailureOr<StringAttr> objectResult =
2381 FieldParser<StringAttr>::parse(odsParser);
2382 if (
failed(objectResult))
2384 "failed to parse GPU_ObjectAttr parameter "
2385 "'object' which is to be a `StringAttr`");
2386 format = *formatResult;
2387 object = *objectResult;
2391void printObject(AsmPrinter &odsParser, CompilationTarget format,
2392 StringAttr
object) {
2393 if (format != CompilationTarget::Fatbin)
2394 odsParser << stringifyEnum(format) <<
" = ";
2395 odsParser << object;
2408 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(
target)) {
2409 if (intAttr.getInt() < 0) {
2410 return emitError() <<
"the object index must be positive";
2412 }
else if (!
target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2414 <<
"the target attribute must be a GPU Target attribute";
2424LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2425 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2426 return emitOpError() <<
"must be inside an op with symbol table";
2428 MemRefType memrefType = getResultMemref().getType();
2430 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2432 << gpu::AddressSpaceAttr::getMnemonic() <<
"<"
2433 << stringifyEnum(gpu::AddressSpace::Workgroup) <<
">";
2435 if (memrefType.hasStaticShape()) {
2436 return emitOpError() <<
"result memref type must be memref<?xi8, "
2437 "#gpu.address_space<workgroup>>";
2446void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2447 p <<
"(" << getLaneid() <<
")";
2449 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2450 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2451 p <<
"[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() <<
"]";
2453 if (!getArgs().empty())
2454 p <<
" args(" << getArgs() <<
" : " << getArgs().getTypes() <<
")";
2455 if (!getResults().empty())
2456 p <<
" -> (" << getResults().getTypes() <<
')';
2460 !getResults().empty());
2464ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2465 OperationState &
result) {
2467 result.regions.reserve(1);
2468 Region *warpRegion =
result.addRegion();
2471 OpAsmParser::UnresolvedOperand laneId;
2483 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2490 llvm::SMLoc inputsOperandsLoc;
2491 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2492 SmallVector<Type> inputTypes;
2502 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2513 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder,
result.location);
2521void WarpExecuteOnLane0Op::getSuccessorRegions(
2522 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2529 regions.push_back(RegionSuccessor(&getWarpRegion()));
2532ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2535void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2538 build(builder,
result, resultTypes, laneId, warpSize,
2542void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2546 result.addOperands(laneId);
2547 result.addAttribute(getAttributeNames()[0],
2549 result.addTypes(resultTypes);
2550 result.addOperands(args);
2551 assert(args.size() == blockArgTypes.size());
2552 OpBuilder::InsertionGuard guard(builder);
2553 Region *warpRegion =
result.addRegion();
2555 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2564 if (expanded == distributed)
2566 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2567 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2568 if (!expandedVecType || !distributedVecType)
2569 return op->
emitOpError(
"expected vector type for distributed operands.");
2570 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2571 expandedVecType.getElementType() != distributedVecType.getElementType())
2573 "expected distributed vectors to have same rank and element type.");
2576 for (
int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2577 int64_t eDim = expandedVecType.getDimSize(i);
2578 int64_t dDim = distributedVecType.getDimSize(i);
2581 if (eDim % dDim != 0)
2583 <<
"expected expanded vector dimension #" << i <<
" (" << eDim
2584 <<
") to be a multipler of the distributed vector dimension ("
2586 scales[i] = eDim / dDim;
2588 if (llvm::product_of(scales) != warpSize)
2590 <<
"incompatible distribution dimensions from " << expandedVecType
2591 <<
" to " << distributedVecType <<
" with warp size = " << warpSize;
2596LogicalResult WarpExecuteOnLane0Op::verify() {
2597 if (getArgs().size() != getWarpRegion().getNumArguments())
2599 "expected same number op arguments and block arguments.");
2600 auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
2602 return emitOpError(
"expected body to be terminated with 'gpu.yield'");
2603 if (yield.getNumOperands() != getNumResults())
2605 "expected same number of yield operands and return values.");
2606 int64_t warpSize = getWarpSize();
2607 for (
auto [regionArg, arg] :
2608 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2610 warpSize, getOperation())))
2613 for (
auto [yieldOperand,
result] :
2614 llvm::zip_equal(yield.getOperands(), getResults())) {
2616 warpSize, getOperation())))
2621bool WarpExecuteOnLane0Op::areTypesCompatible(Type
lhs, Type
rhs) {
2626gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2627 return cast<gpu::YieldOp>(getBody()->getTerminator());
2634void gpu::SubgroupBroadcastOp::inferResultRanges(
2635 ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
2636 setResultRange(getResult(), argRanges.front());
2640 switch (getBroadcastType()) {
2641 case BroadcastType::first_active_lane:
2645 case BroadcastType::specific_lane:
2649 llvm_unreachable(
"Unknown BroadcastType");
2652LogicalResult gpu::SubgroupBroadcastOp::verify() {
2653 switch (getBroadcastType()) {
2654 case BroadcastType::first_active_lane:
2657 <<
"lane can only be specified for `specific_lane` broadcast";
2659 case BroadcastType::specific_lane:
2662 <<
"lane must be specified for `specific_lane` broadcast";
2665 llvm_unreachable(
"Unknown BroadcastType");
2668OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor ) {
2670 if (
auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2671 return prev.getResult();
2686KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2687 DictionaryAttr metadata) {
2688 assert(kernel &&
"invalid kernel");
2689 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2690 kernel.getAllArgAttrs(), metadata);
2695 FunctionOpInterface kernel,
2696 DictionaryAttr metadata) {
2697 assert(kernel &&
"invalid kernel");
2699 kernel.getAllArgAttrs(), metadata);
2703KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs)
const {
2706 NamedAttrList attrList;
2707 if (DictionaryAttr dict = getMetadata())
2710 return KernelMetadataAttr::get(getName(), getFunctionType(),
getArgAttrs(),
2716 StringAttr name, Type functionType,
2717 ArrayAttr argAttrs, DictionaryAttr metadata) {
2719 return emitError() <<
"the kernel name can't be empty";
2721 if (llvm::any_of(argAttrs, [](Attribute attr) {
2722 return !llvm::isa<DictionaryAttr>(attr);
2725 <<
"all attributes in the array must be a dictionary attribute";
2734KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2735 ArrayRef<KernelMetadataAttr> kernels,
2738 assert((!isSorted || llvm::is_sorted(kernels)) &&
2739 "expected a sorted kernel array");
2741 if (isSorted || llvm::is_sorted(kernels))
2742 return Base::get(context, kernels);
2744 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2745 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2746 return Base::get(context, kernelsTmp);
2749KernelTableAttr KernelTableAttr::getChecked(
2751 ArrayRef<KernelMetadataAttr> kernels,
bool isSorted) {
2753 assert((!isSorted || llvm::is_sorted(kernels)) &&
2754 "expected a sorted kernel array");
2756 if (isSorted || llvm::is_sorted(kernels))
2757 return Base::getChecked(
emitError, context, kernels);
2759 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2760 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2761 return Base::getChecked(
emitError, context, kernelsTmp);
2766 ArrayRef<KernelMetadataAttr> kernels) {
2767 if (kernels.size() < 2)
2770 if (std::adjacent_find(kernels.begin(), kernels.end(),
2771 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2772 return l.getName() == r.getName();
2773 }) != kernels.end()) {
2774 return emitError() <<
"expected all kernels to be uniquely named";
2779KernelMetadataAttr KernelTableAttr::lookup(StringRef key)
const {
2781 return found ? *iterator : KernelMetadataAttr();
2784KernelMetadataAttr KernelTableAttr::lookup(StringAttr key)
const {
2786 return found ? *iterator : KernelMetadataAttr();
2866 return CompilationTarget::Fatbin;
2869std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2871 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
options;
2872 llvm::StringSaver stringSaver(
options.first);
2878 if (!opts.empty() && opts.front() ==
'"' && opts.back() ==
'"')
2879 opts.consume_front(
"\""), opts.consume_back(
"\"");
2880 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2881 opts.consume_front(
"'"), opts.consume_back(
"'");
2883 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2886 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2892std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2897std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2899 size_t startPos =
cmdOptions.find(startsWith);
2900 if (startPos == std::string::npos)
2911#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2912#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2914#define GET_ATTRDEF_CLASSES
2915#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2917#define GET_OP_CLASSES
2918#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2920#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper to check whether the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of `from` and replace them with `to`.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation `from`.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes, prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on inval...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through `.x`, `.y`, and `.z`.