36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
49#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
55int64_t GPUBlockMappingAttr::getMappingId()
const {
56 return static_cast<int64_t>(getBlock());
59bool GPUBlockMappingAttr::isLinearMapping()
const {
60 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
63int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
64 return isLinearMapping()
65 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
69int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
70 return static_cast<int64_t>(getWarpgroup());
73bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
74 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
77int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
78 return isLinearMapping()
79 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
83int64_t GPUWarpMappingAttr::getMappingId()
const {
84 return static_cast<int64_t>(getWarp());
87bool GPUWarpMappingAttr::isLinearMapping()
const {
88 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
91int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
92 return isLinearMapping()
93 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
97int64_t GPUThreadMappingAttr::getMappingId()
const {
98 return static_cast<int64_t>(getThread());
101bool GPUThreadMappingAttr::isLinearMapping()
const {
102 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
105int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
106 return isLinearMapping()
107 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
111int64_t GPULaneMappingAttr::getMappingId()
const {
112 return static_cast<int64_t>(getLane());
115bool GPULaneMappingAttr::isLinearMapping()
const {
116 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
119int64_t GPULaneMappingAttr::getRelativeIndex()
const {
120 return isLinearMapping()
121 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
125int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds()
const {
return 64; }
// Builds IR that maps a physical linear id to its logical (compacted) id:
// popcount(mask & ((1 << physicalLinearMappingId) - 1)) counts the active
// bits strictly below the physical id. NOTE(review): the signature's
// parameter list and the `Value mask = ...` prefix are not visible in this
// view — presumably `(OpBuilder &b, Value physicalLinearMappingId)`; confirm
// against the header.
137Value GPUMappingMaskAttr::createLogicalLinearMappingId(
141 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
142 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
// filter = (1 << physicalLinearMappingId) - 1, i.e. all bits below the id.
143 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
144 filter = arith::SubIOp::create(
b, loc, filter, one);
// Keep only the active (mask) bits below the physical id ...
145 Value filteredId = arith::AndIOp::create(
b, loc, mask, filter);
// ... and count them: this popcount is the logical linear id.
146 return math::CtPopOp::create(
b, loc, filteredId);
// Builds IR computing an i1 predicate that is true iff the physical id's bit
// is set in the mask: (mask & (1 << physicalLinearMappingId)) != 0.
// NOTE(review): the signature's parameter list and the final `zero)` operand
// of the CmpIOp are not visible in this view — confirm against the header.
159Value GPUMappingMaskAttr::createIsActiveIdPredicate(
163 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
164 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
// filter = 1 << physicalLinearMappingId, a single-bit probe for the id.
165 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
// Non-zero result means the id's bit is active in the mask.
166 Value filtered = arith::AndIOp::create(
b, loc, mask, filter);
167 Value zero = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(0));
168 return arith::CmpIOp::create(
b, loc, arith::CmpIPredicate::ne, filtered,
172int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
173 return static_cast<int64_t>(getAddressSpace());
176bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
177 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
180int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
181 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
198 elementType, operand);
212 return elementType.
isF16() || elementType.
isF32() || elementType.
isF64() ||
221 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
222 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
224 if (
shape.size() != 2)
225 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
229 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
238bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
241 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
242 return gpuAttr.getValue() == getWorkgroupAddressSpace();
246bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
247 Attribute memorySpace = type.getMemorySpace();
248 return isWorkgroupMemoryAddressSpace(memorySpace);
251bool GPUDialect::isKernel(
Operation *op) {
252 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
253 return static_cast<bool>(isKernelAttr);
259struct GPUInlinerInterface :
public DialectInlinerInterface {
260 using DialectInlinerInterface::DialectInlinerInterface;
263 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
269void GPUDialect::initialize() {
270 addTypes<AsyncTokenType>();
271 addTypes<MMAMatrixType>();
272 addTypes<SparseDnTensorHandleType>();
273 addTypes<SparseSpMatHandleType>();
274 addTypes<SparseSpGEMMOpHandleType>();
277#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
280#define GET_ATTRDEF_LIST
281#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
283 addInterfaces<GPUInlinerInterface>();
284 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
286 declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
287 ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
288 BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
289 LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
290 SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
296 return "sparse.dntensor_handle";
298 return "sparse.spmat_handle";
300 return "sparse.spgemmop_handle";
302 llvm_unreachable(
"unknown sparse handle kind");
306Type GPUDialect::parseType(DialectAsmParser &parser)
const {
314 if (keyword ==
"async.token")
317 if (keyword ==
"mma_matrix") {
325 SmallVector<int64_t> shape;
346 shape, elementType, operand);
361void GPUDialect::printType(Type type, DialectAsmPrinter &os)
const {
364 .Case<SparseDnTensorHandleType>([&](Type) {
367 .Case<SparseSpMatHandleType>(
369 .Case<SparseSpGEMMOpHandleType>([&](Type) {
375 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
378 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
380 .DefaultUnreachable(
"unexpected 'gpu' type kind");
385 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
388 " must be a dense i32 array");
389 if (array.size() != 3)
391 " must contain exactly 3 elements");
395LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
396 NamedAttribute attr) {
397 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
399 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
401 if (attr.
getName() == getKnownClusterSizeAttrHelper().getName())
403 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
404 attr.
getName() != getContainerModuleAttrName())
407 auto module = dyn_cast<ModuleOp>(op);
410 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
411 << ModuleOp::getOperationName() <<
'\'';
413 auto walkResult =
module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
416 if (!launchOp->getParentOp() ||
417 launchOp->getParentOp()->getParentOp() != module)
422 if (!launchOp->getAttrOfType<SymbolRefAttr>(
423 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
427 StringAttr kernelContainerName = launchOp.getKernelModuleName();
428 Operation *kernelContainer =
module.lookupSymbol(kernelContainerName);
429 if (!kernelContainer)
430 return launchOp.emitOpError()
431 <<
"kernel container '" << kernelContainerName.getValue()
435 if (isa<BinaryOp>(kernelContainer))
438 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
440 return launchOp.emitOpError()
441 <<
"kernel module '" << kernelContainerName.getValue()
445 Operation *kernelFunc =
module.lookupSymbol(launchOp.getKernelAttr());
447 return launchOp.emitOpError(
"kernel function '")
448 << launchOp.getKernel() <<
"' is undefined";
449 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
450 if (!kernelConvertedFunction) {
451 InFlightDiagnostic
diag = launchOp.emitOpError()
452 <<
"referenced kernel '" << launchOp.getKernel()
453 <<
"' is not a function";
454 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
459 GPUDialect::getKernelFuncAttrName()))
460 return launchOp.emitOpError(
"kernel function is missing the '")
461 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
466 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
467 if (!kernelGPUFunction)
470 unsigned actualNumArguments = launchOp.getNumKernelOperands();
471 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
472 if (expectedNumArguments != actualNumArguments)
473 return launchOp.emitOpError(
"got ")
474 << actualNumArguments <<
" kernel operands but expected "
475 << expectedNumArguments;
477 auto functionType = kernelGPUFunction.getFunctionType();
478 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
479 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
480 return launchOp.emitOpError(
"type of function argument ")
481 << i <<
" does not match";
488 return walkResult.wasInterrupted() ? failure() :
success();
501 return parser.
emitError(loc,
"needs to be named when marked 'async'");
516 if (asyncDependencies.empty())
520 printer << llvm::interleaved_array(asyncDependencies);
548 p <<
' ' << keyword <<
'(';
549 llvm::interleaveComma(
550 llvm::enumerate(values), p, [&p, attributes](
auto pair) {
551 BlockArgument v = pair.value();
552 p << v <<
" : " << v.
getType();
554 size_t attributionIndex = pair.index();
555 DictionaryAttr attrs;
556 if (attributes && attributionIndex < attributes.size())
557 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
567 gpu::AddressSpace memorySpace) {
568 for (
Value v : attributions) {
569 auto type = llvm::dyn_cast<MemRefType>(v.
getType());
571 return op->
emitOpError() <<
"expected memref type in attribution";
576 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
579 if (addressSpace.getValue() != memorySpace)
581 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
582 <<
" in attribution";
593 using Kind = gpu::AllReduceOperation;
594 if (llvm::is_contained(
595 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
597 if (!isa<FloatType>(resType))
601 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
602 Kind::AND, Kind::OR, Kind::XOR},
604 if (!isa<IntegerType>(resType))
611LogicalResult gpu::AllReduceOp::verifyRegions() {
612 if (getBody().empty() != getOp().has_value())
613 return emitError(
"expected either an op attribute or a non-empty body");
614 if (!getBody().empty()) {
615 if (getBody().getNumArguments() != 2)
616 return emitError(
"expected two region arguments");
617 for (
auto argument : getBody().getArguments()) {
618 if (argument.getType() !=
getType())
619 return emitError(
"incorrect region argument type");
621 unsigned yieldCount = 0;
622 for (
Block &block : getBody()) {
623 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
624 if (yield.getNumOperands() != 1)
625 return emitError(
"expected one gpu.yield operand");
626 if (yield.getOperand(0).getType() !=
getType())
627 return emitError(
"incorrect gpu.yield type");
632 return emitError(
"expected gpu.yield op in region");
634 gpu::AllReduceOperation opName = *getOp();
636 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
637 <<
"` reduction operation is not compatible with type "
646 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
650 Region &body = launchOp.getBody();
651 assert(!body.
empty() &&
"Invalid region");
657OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor ) {
668 AllReduceOperationAttr &attr) {
671 std::optional<AllReduceOperation> op =
672 gpu::symbolizeAllReduceOperation(enumStr);
675 attr = AllReduceOperationAttr::get(parser.
getContext(), *op);
681 AllReduceOperationAttr attr) {
690LogicalResult gpu::SubgroupReduceOp::verify() {
692 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
693 if (vecTy.isScalable())
694 return emitOpError() <<
"is not compatible with scalable vector types";
696 elemType = vecTy.getElementType();
699 gpu::AllReduceOperation opName = getOp();
701 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
702 <<
"` reduction operation is not compatible with type "
706 auto clusterSize = getClusterSize();
708 uint32_t size = *clusterSize;
709 if (!llvm::isPowerOf2_32(size)) {
711 <<
" is not a power of two";
715 uint32_t stride = getClusterStride();
716 if (stride != 1 && !clusterSize) {
717 return emitOpError() <<
"cluster stride can only be specified if cluster "
720 if (!llvm::isPowerOf2_32(stride)) {
721 return emitOpError() <<
"cluster stride " << stride
722 <<
" is not a power of two";
728OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
729 if (getClusterSize() == 1)
746 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
750 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
768 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
778 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
782 result.addOperands(asyncDependencies);
787 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
788 getBlockSizeY, getBlockSizeZ});
790 result.addOperands(clusterSizeX);
792 result.addOperands(clusterSizeY);
794 result.addOperands(clusterSizeZ);
795 if (dynamicSharedMemorySize)
796 result.addOperands(dynamicSharedMemorySize);
800 result.addAttribute(getModuleAttrName(
result.name), module);
802 result.addAttribute(getFunctionAttrName(
result.name), function);
810 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
813 for (
Type argTy : workgroupAttributions)
815 for (
Type argTy : privateAttributions)
819 segmentSizes.front() = asyncDependencies.size();
820 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
821 segmentSizes[7] = clusterSizeX ? 1 : 0;
822 segmentSizes[8] = clusterSizeY ? 1 : 0;
823 segmentSizes[9] = clusterSizeZ ? 1 : 0;
824 result.addAttribute(getOperandSegmentSizeAttr(),
829 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
830 auto args = getBody().getArguments();
835 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
836 auto args = getBody().getArguments();
841 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
842 auto args = getBody().getArguments();
847 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
848 auto args = getBody().getArguments();
849 return KernelDim3{args[9], args[10], args[11]};
852std::optional<KernelDim3> LaunchOp::getClusterIds() {
853 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
854 if (!hasClusterSize())
856 auto args = getBody().getArguments();
857 return KernelDim3{args[12], args[13], args[14]};
860std::optional<KernelDim3> LaunchOp::getClusterSize() {
861 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
862 if (!hasClusterSize())
864 auto args = getBody().getArguments();
865 return KernelDim3{args[15], args[16], args[17]};
868KernelDim3 LaunchOp::getGridSizeOperandValues() {
869 auto operands = getOperands().drop_front(getAsyncDependencies().size());
870 return KernelDim3{operands[0], operands[1], operands[2]};
873KernelDim3 LaunchOp::getBlockSizeOperandValues() {
874 auto operands = getOperands().drop_front(getAsyncDependencies().size());
875 return KernelDim3{operands[3], operands[4], operands[5]};
878std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
879 auto operands = getOperands().drop_front(getAsyncDependencies().size());
880 if (!hasClusterSize())
882 return KernelDim3{operands[6], operands[7], operands[8]};
885LogicalResult LaunchOp::verify() {
886 if (!(hasClusterSize()) &&
887 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
888 return emitOpError() <<
"cluster size must be all present";
892LogicalResult LaunchOp::verifyRegions() {
896 if (getBody().empty()) {
899 if (getBody().getNumArguments() <
900 kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
901 return emitOpError(
"unexpected number of region arguments");
906 GPUDialect::getWorkgroupAddressSpace())) ||
908 GPUDialect::getPrivateAddressSpace())))
913 for (
Block &block : getBody()) {
916 if (block.back().getNumSuccessors() != 0)
918 if (!isa<gpu::TerminatorOp>(&block.back())) {
921 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
922 "' or a terminator with successors")
923 .attachNote(getLoc())
924 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
928 if (getNumResults() == 0 && getAsyncToken())
929 return emitOpError(
"needs to be named when async keyword is specified");
940 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
941 p << size.
x <<
" = " << operands.
x <<
", ";
942 p << size.
y <<
" = " << operands.
y <<
", ";
943 p << size.
z <<
" = " << operands.
z <<
')';
946void LaunchOp::print(OpAsmPrinter &p) {
947 if (getAsyncToken()) {
949 if (!getAsyncDependencies().empty())
950 p <<
" [" << getAsyncDependencies() <<
']';
953 if (hasClusterSize()) {
954 p <<
' ' << getClustersKeyword();
956 getClusterSizeOperandValues().value(),
957 getClusterIds().value());
959 p <<
' ' << getBlocksKeyword();
962 p <<
' ' << getThreadsKeyword();
965 if (getDynamicSharedMemorySize())
966 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
967 << getDynamicSharedMemorySize();
970 StringRef moduleAttrName = getModuleAttrName();
971 if (
auto module = getModule()) {
972 p <<
' ' << moduleAttrName <<
'(';
977 StringRef functionAttrName = getFunctionAttrName();
978 if (
auto function = getFunction()) {
979 p <<
' ' << functionAttrName <<
'(';
991 LaunchOp::getOperandSegmentSizeAttr(),
992 getNumWorkgroupAttributionsAttrName(),
993 moduleAttrName, functionAttrName});
1007 StringRef keyword) {
1008 assert(
indices.size() == 3 &&
"space for three indices expected");
1015 if (args.size() != 3) {
1017 << keyword <<
" expects 3 arguments, but got " << args.size();
1019 std::move(args.begin(), args.end(),
indices.begin());
1021 for (
int i = 0; i < 3; ++i) {
1043ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &
result) {
1045 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
1046 sizes(LaunchOp::kNumConfigOperands);
1049 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
1050 LaunchOp::kNumConfigRegionAttributes);
1053 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
1054 Type asyncTokenType;
1061 if (!asyncTokenType)
1064 "gpu.launch requires 'async' keyword to return a value");
1065 result.types.push_back(asyncTokenType);
1068 bool hasCluster =
false;
1072 regionArgs.resize(18);
1074 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
1075 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1081 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1082 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1090 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword()) ||
1092 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1093 LaunchOp::getBlocksKeyword()) ||
1096 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1097 LaunchOp::getThreadsKeyword()) ||
1102 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1103 bool hasDynamicSharedMemorySize =
false;
1105 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1106 hasDynamicSharedMemorySize =
true;
1115 StringRef moduleAttrName = getModuleAttrName(
result.name);
1117 FlatSymbolRefAttr moduleSymbol;
1125 StringRef functionAttrName = getFunctionAttrName(
result.name);
1127 FlatSymbolRefAttr funcSymbol;
1142 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1143 LaunchOp::kNumConfigRegionAttributes + 6, index);
1145 SmallVector<OpAsmParser::Argument> regionArguments;
1146 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1147 OpAsmParser::Argument arg;
1148 arg.
ssaName = std::get<0>(ssaValueAndType);
1149 arg.
type = std::get<1>(ssaValueAndType);
1150 regionArguments.push_back(arg);
1161 unsigned numWorkgroupAttrs = regionArguments.size() -
1162 LaunchOp::kNumConfigRegionAttributes -
1163 (hasCluster ? 6 : 0);
1164 result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1175 Region *body =
result.addRegion();
1180 SmallVector<int32_t, 11> segmentSizes(11, 1);
1181 segmentSizes.front() = asyncDependencies.size();
1184 segmentSizes[7] = 0;
1185 segmentSizes[8] = 0;
1186 segmentSizes[9] = 0;
1188 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1189 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1203 bool simplified =
false;
1204 auto constPropIdUses = [&](
Value id,
Value size) {
1208 if (
id.getUses().empty())
1220 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1221 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1222 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1223 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1224 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1225 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1231void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1232 MLIRContext *context) {
1233 rewrites.
add<FoldLaunchArguments>(context);
1238BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1239 auto attrName = getNumWorkgroupAttributionsAttrName();
1240 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1241 (*this)->setAttr(attrName,
1242 IntegerAttr::get(attr.getType(), attr.
getValue() + 1));
1243 return getBody().insertArgument(
1244 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1249BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1252 return getBody().addArgument(type, loc);
1259void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1260 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1261 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1262 ValueRange kernelOperands, Type asyncTokenType,
1264 std::optional<KernelDim3> clusterSize) {
1265 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1266 "expected a symbol reference with a single nested reference");
1267 result.addOperands(asyncDependencies);
1274 if (clusterSize.has_value())
1275 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1276 if (dynamicSharedMemorySize)
1277 result.addOperands(dynamicSharedMemorySize);
1278 result.addOperands(kernelOperands);
1280 Properties &prop =
result.getOrAddProperties<Properties>();
1281 prop.kernel = kernelSymbol;
1282 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1284 llvm::fill(prop.operandSegmentSizes, 1);
1285 prop.operandSegmentSizes[0] = asyncDependencies.size();
1286 if (!clusterSize.has_value()) {
1287 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1288 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1289 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1291 prop.operandSegmentSizes[segmentSizesLen - 3] =
1292 dynamicSharedMemorySize ? 1 : 0;
1293 prop.operandSegmentSizes[segmentSizesLen - 2] =
1294 static_cast<int32_t
>(kernelOperands.size());
1295 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1298void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1300 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1301 ValueRange kernelOperands, Type asyncTokenType,
1303 std::optional<KernelDim3> clusterSize) {
1304 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1306 SymbolRefAttr::get(kernelModule.getNameAttr(),
1307 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1308 build(builder,
result, kernelSymbol, gridSize, getBlockSize,
1309 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1310 asyncDependencies, clusterSize);
1313void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1315 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1316 ValueRange kernelOperands, Value asyncObject,
1317 std::optional<KernelDim3> clusterSize) {
1321 if (clusterSize.has_value())
1322 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1323 if (dynamicSharedMemorySize)
1324 result.addOperands(dynamicSharedMemorySize);
1325 result.addOperands(kernelOperands);
1327 result.addOperands(asyncObject);
1328 Properties &prop =
result.getOrAddProperties<Properties>();
1329 prop.kernel = kernel;
1330 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1332 llvm::fill(prop.operandSegmentSizes, 1);
1333 prop.operandSegmentSizes[0] = 0;
1334 if (!clusterSize.has_value()) {
1335 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1336 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1337 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1339 prop.operandSegmentSizes[segmentSizesLen - 3] =
1340 dynamicSharedMemorySize ? 1 : 0;
1341 prop.operandSegmentSizes[segmentSizesLen - 2] =
1342 static_cast<int32_t
>(kernelOperands.size());
1343 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1346StringAttr LaunchFuncOp::getKernelModuleName() {
1350StringAttr LaunchFuncOp::getKernelName() {
1354unsigned LaunchFuncOp::getNumKernelOperands() {
1355 return getKernelOperands().size();
1358Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1359 return getKernelOperands()[i];
1362KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1363 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1364 return KernelDim3{operands[0], operands[1], operands[2]};
1367KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1368 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1369 return KernelDim3{operands[3], operands[4], operands[5]};
1372KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1373 assert(hasClusterSize() &&
1374 "cluster size is not set, check hasClusterSize() first");
1375 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1376 return KernelDim3{operands[6], operands[7], operands[8]};
1379LogicalResult LaunchFuncOp::verify() {
1380 auto module = (*this)->getParentOfType<ModuleOp>();
1382 return emitOpError(
"expected to belong to a module");
1384 if (!module->getAttrOfType<UnitAttr>(
1385 GPUDialect::getContainerModuleAttrName()))
1386 return emitOpError(
"expected the closest surrounding module to have the '" +
1387 GPUDialect::getContainerModuleAttrName() +
1390 if (hasClusterSize()) {
1391 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1394 <<
"expects types of the cluster dimensions must be the same";
1402 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1403 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1410 if (clusterValue.has_value()) {
1411 clusterXTy = clusterYTy = clusterZTy = dimTy;
1418 Type clusterYTy,
Type clusterZTy) {
1420 printer <<
": " << dimTy;
1430 auto parseElement = [&]() -> ParseResult {
1431 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1436 parseElement,
" in argument list");
1441 if (operands.empty())
1444 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1445 [&](
const auto &pair) {
1446 auto [operand, type] = pair;
1447 printer << operand <<
" : " << type;
1456void ShuffleOp::build(OpBuilder &builder, OperationState &
result, Value value,
1457 int32_t offset, int32_t width, ShuffleMode mode) {
1458 build(builder,
result, value,
1459 arith::ConstantOp::create(builder,
result.location,
1461 arith::ConstantOp::create(builder,
result.location,
1470LogicalResult RotateOp::verify() {
1471 uint32_t offset = getOffset();
1472 uint32_t width = getWidth();
1474 if (offset >= width) {
1475 return emitOpError() <<
"offset must be in the range [0, " << width <<
")";
1488 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1492 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1493 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1497 if (!nextMemfence) {
1498 op.removeAddressSpacesAttr();
1502 if (*thisMemfence == *nextMemfence) {
1506 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1508 mergedSpaces.insert(attr);
1510 mergedSpaces.insert(attr);
1511 op.setAddressSpacesAttr(rewriter.
getArrayAttr(mergedSpaces.takeVector()));
1519void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1520 MLIRContext *context) {
1524void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1525 mlir::OperationState &odsState,
1526 std::optional<AddressSpace> addressSpace) {
1530 AddressSpaceAttr::get(odsBuilder.
getContext(), addressSpace.value()));
1531 build(odsBuilder, odsState, addressSpacesAttr);
1538void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1539 Value memrefToFence) {
1540 std::optional<AddressSpace> addrSpaceToFence;
1541 if (
auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.
getType()))
1542 if (
auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1543 memrefType.getMemorySpace()))
1544 addrSpaceToFence = addrSpaceAttr.getValue();
1545 return build(builder, odsState, addrSpaceToFence);
1554BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1555 auto attrName = getNumWorkgroupAttributionsAttrName();
1556 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1557 (*this)->setAttr(attrName,
1558 IntegerAttr::get(attr.getType(), attr.
getValue() + 1));
1559 return getBody().insertArgument(
1560 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1565BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1568 return getBody().addArgument(type, loc);
1571void GPUFuncOp::build(OpBuilder &builder, OperationState &
result,
1572 StringRef name, FunctionType type,
1575 ArrayRef<NamedAttribute> attrs) {
1576 OpBuilder::InsertionGuard g(builder);
1580 result.addAttribute(getFunctionTypeAttrName(
result.name),
1581 TypeAttr::get(type));
1582 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1584 result.addAttributes(attrs);
1585 Region *body =
result.addRegion();
1589 for (Type argTy : type.getInputs())
1591 for (Type argTy : workgroupAttributions)
1593 for (Type argTy : privateAttributions)
1612 size_t existingArgs = args.size();
1619 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1624 attributionAttrs =
nullptr;
1630 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1631 if (!argument.attrs)
1634 attributionAttrsVec.push_back(argument.attrs);
1636 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1645ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &
result) {
1646 SmallVector<OpAsmParser::Argument> entryArgs;
1647 SmallVector<DictionaryAttr> resultAttrs;
1648 SmallVector<Type> resultTypes;
1652 StringAttr nameAttr;
1659 parser,
false, entryArgs, isVariadic, resultTypes,
1663 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1664 return parser.
emitError(signatureLocation)
1665 <<
"gpu.func requires named arguments";
1671 SmallVector<Type> argTypes;
1672 for (
auto &arg : entryArgs)
1673 argTypes.push_back(arg.
type);
1675 result.addAttribute(getFunctionTypeAttrName(
result.name),
1676 TypeAttr::get(type));
1679 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
1680 getResAttrsAttrName(
result.name));
1682 Attribute workgroupAttributionAttrs;
1685 entryArgs, workgroupAttributionAttrs)))
1690 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1691 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1693 if (workgroupAttributionAttrs)
1694 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(
result.name),
1695 workgroupAttributionAttrs);
1697 Attribute privateAttributionAttrs;
1700 entryArgs, privateAttributionAttrs)))
1702 if (privateAttributionAttrs)
1703 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(
result.name),
1704 privateAttributionAttrs);
1708 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1717 auto *body =
result.addRegion();
1721void GPUFuncOp::print(OpAsmPrinter &p) {
1725 FunctionType type = getFunctionType();
1731 getWorkgroupAttribAttrs().value_or(
nullptr));
1733 getPrivateAttribAttrs().value_or(
nullptr));
1735 p <<
' ' << getKernelKeyword();
1739 {getNumWorkgroupAttributionsAttrName(),
1740 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1741 getArgAttrsAttrName(), getResAttrsAttrName(),
1742 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1748 StringAttr attrName) {
1749 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1750 if (!allAttrs ||
index >= allAttrs.size())
1751 return DictionaryAttr();
1752 return llvm::cast<DictionaryAttr>(allAttrs[
index]);
1755DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1759DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1764 DictionaryAttr value, StringAttr attrName) {
1766 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1769 elements.append(allAttrs.begin(), allAttrs.end());
1770 while (elements.size() <=
index)
1771 elements.push_back(DictionaryAttr::get(ctx));
1773 elements[
index] = DictionaryAttr::get(ctx);
1775 elements[
index] = value;
1776 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1777 op->setAttr(attrName, newValue);
1780void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1781 DictionaryAttr value) {
1785void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1786 DictionaryAttr value) {
1791 StringAttr name, StringAttr attrsName) {
1795 return dict.get(name);
1798Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1800 assert(index < getNumWorkgroupAttributions() &&
1801 "index must map to a workgroup attribution");
1803 getWorkgroupAttribAttrsAttrName());
1806Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1808 assert(index < getNumPrivateAttributions() &&
1809 "index must map to a private attribution");
1811 getPrivateAttribAttrsAttrName());
1815 Attribute value, StringAttr attrsName) {
1820 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1823 bool mustSort =
true;
1824 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1825 if (elems[i].getName() == name) {
1828 std::swap(elems[i], elems[elems.size() - 1]);
1840 elems.emplace_back(name, value);
1843 DictionaryAttr::sortInPlace(elems);
1845 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1849void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1851 assert(index < getNumWorkgroupAttributions() &&
1852 "index must map to a workgroup attribution");
1854 getWorkgroupAttribAttrsAttrName());
1857void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1859 assert(index < getNumPrivateAttributions() &&
1860 "index must map to a private attribution");
1862 getPrivateAttribAttrsAttrName());
1865LogicalResult GPUFuncOp::verifyType() {
1866 if (isKernel() && getFunctionType().getNumResults() != 0)
1867 return emitOpError() <<
"expected void return type for kernel function";
1873LogicalResult GPUFuncOp::verifyBody() {
1875 return emitOpError() <<
"expected body with at least one block";
1876 unsigned numFuncArguments = getNumArguments();
1877 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1878 unsigned numBlockArguments = front().getNumArguments();
1879 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1881 << numFuncArguments + numWorkgroupAttributions
1882 <<
" arguments to body region";
1884 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1885 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1886 Type blockArgType = front().getArgument(i).getType();
1887 if (funcArgTypes[i] != blockArgType)
1888 return emitOpError() <<
"expected body region argument #" << i
1889 <<
" to be of type " << funcArgTypes[i] <<
", got "
1894 GPUDialect::getWorkgroupAddressSpace())) ||
1896 GPUDialect::getPrivateAddressSpace())))
1906LogicalResult gpu::ReturnOp::verify() {
1907 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1909 FunctionType funType = function.getFunctionType();
1911 if (funType.getNumResults() != getOperands().size())
1913 .append(
"expected ", funType.getNumResults(),
" result operands")
1914 .attachNote(function.getLoc())
1915 .append(
"return type declared here");
1917 for (
const auto &pair : llvm::enumerate(
1918 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1919 auto [type, operand] = pair.value();
1920 if (type != operand.getType())
1921 return emitOpError() <<
"unexpected type `" << operand.getType()
1922 <<
"' for operand #" << pair.index();
1931void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1933 Attribute offloadingHandler) {
1934 result.addRegion()->emplaceBlock();
1935 Properties &props =
result.getOrAddProperties<Properties>();
1937 props.targets = targets;
1939 props.offloadingHandler = offloadingHandler;
1942void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1943 StringRef name, ArrayRef<Attribute> targets,
1944 Attribute offloadingHandler) {
1945 build(builder,
result, name,
1950bool GPUModuleOp::hasTarget(Attribute
target) {
1951 if (
ArrayAttr targets = getTargetsAttr())
1952 return llvm::count(targets.getValue(),
target);
1956void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1957 ArrayAttr &targetsAttr = getProperties().targets;
1958 SmallVector<Attribute> targetsVector(targets);
1959 targetsAttr = ArrayAttr::get(
getContext(), targetsVector);
1962LogicalResult GPUModuleOp::verify() {
1963 auto targets = getOperation()->getAttrOfType<
ArrayAttr>(
"targets");
1968 for (
auto target : targets) {
1969 if (
auto verifyTargetAttr =
1970 llvm::dyn_cast<TargetAttrVerifyInterface>(
target)) {
1971 if (verifyTargetAttr.verifyTarget(getOperation()).
failed())
1981void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
1982 Attribute offloadingHandler,
ArrayAttr objects) {
1983 auto &properties =
result.getOrAddProperties<Properties>();
1986 properties.objects = objects;
1987 if (offloadingHandler)
1988 properties.offloadingHandler = offloadingHandler;
1990 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
1993void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
1994 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1995 build(builder,
result, name, offloadingHandler,
2007 if (!offloadingHandler)
2014 if (offloadingHandler != SelectObjectAttr::get(op->
getContext(),
nullptr))
2015 printer <<
'<' << offloadingHandler <<
'>';
2022LogicalResult MemcpyOp::verify() {
2023 auto srcType = getSrc().getType();
2024 auto dstType = getDst().getType();
2027 return emitOpError(
"arguments have incompatible element type");
2030 return emitOpError(
"arguments have incompatible shape");
2039struct EraseTrivialCopyOp :
public OpRewritePattern<MemcpyOp> {
2040 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2042 LogicalResult matchAndRewrite(MemcpyOp op,
2043 PatternRewriter &rewriter)
const override {
2044 Value dest = op.getDst();
2053 if (llvm::any_of(dest.
getUsers(), [op, dest](Operation *user) {
2054 return user != op &&
2055 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2061 if (op.getAsyncDependencies().size() > 1 ||
2062 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2063 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2065 rewriter.
replaceOp(op, op.getAsyncDependencies());
2072void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2073 MLIRContext *context) {
2074 results.
add<EraseTrivialCopyOp>(context);
2081LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2082 auto srcType = getSrcMemref().getType();
2083 auto resType = getRes().getType();
2084 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2085 auto operand = resMatrixType.getOperand();
2086 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2088 if (!srcMemrefType.isLastDimUnitStride())
2090 "expected source memref most minor dim must have unit stride");
2092 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
2093 return emitError(
"only AOp, BOp and COp can be loaded");
2102LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2103 auto srcType = getSrc().getType();
2104 auto dstType = getDstMemref().getType();
2105 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2106 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2108 if (!dstMemrefType.isLastDimUnitStride())
2110 "expected destination memref most minor dim must have unit stride");
2112 if (srcMatrixType.getOperand() !=
"COp")
2114 "expected the operand matrix being stored to have 'COp' operand type");
2123LogicalResult SubgroupMmaComputeOp::verify() {
2124 enum OperandMap {
A,
B,
C };
2125 SmallVector<MMAMatrixType, 3> opTypes;
2126 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().
getType()));
2127 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().
getType()));
2128 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().
getType()));
2130 if (opTypes[A].getOperand() !=
"AOp" || opTypes[B].getOperand() !=
"BOp" ||
2131 opTypes[C].getOperand() !=
"COp")
2132 return emitError(
"operands must be in the order AOp, BOp, COp");
2134 ArrayRef<int64_t> aShape, bShape, cShape;
2135 aShape = opTypes[
A].getShape();
2136 bShape = opTypes[
B].getShape();
2137 cShape = opTypes[
C].getShape();
2139 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2140 bShape[1] != cShape[1])
2141 return emitError(
"operand shapes do not satisfy matmul constraints");
2146LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2147 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2151LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2152 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2165struct EraseRedundantGpuWaitOpPairs :
public OpRewritePattern<WaitOp> {
2169 LogicalResult matchAndRewrite(WaitOp op,
2170 PatternRewriter &rewriter)
const final {
2171 auto predicate = [](Value value) {
2172 auto waitOp = value.getDefiningOp<WaitOp>();
2173 return waitOp && waitOp->getNumOperands() == 0;
2175 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2177 SmallVector<Value> validOperands;
2178 for (Value operand : op->getOperands()) {
2179 if (predicate(operand))
2181 validOperands.push_back(operand);
2183 rewriter.
modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2195struct SimplifyGpuWaitOp :
public OpRewritePattern<WaitOp> {
2199 LogicalResult matchAndRewrite(WaitOp op,
2200 PatternRewriter &rewriter)
const final {
2203 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2208 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2209 op.getAsyncToken()) {
2210 rewriter.
replaceOp(op, op.getAsyncDependencies());
2214 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2224void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2225 MLIRContext *context) {
2226 results.
add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2233LogicalResult AllocOp::verify() {
2234 auto memRefType = llvm::cast<MemRefType>(getMemref().
getType());
2240 unsigned numSymbols = 0;
2241 if (!memRefType.getLayout().isIdentity())
2242 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2243 if (getSymbolOperands().size() != numSymbols) {
2245 "symbol operand count does not equal memref symbol count");
2255struct SimplifyDimOfAllocOp :
public OpRewritePattern<memref::DimOp> {
2256 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2258 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2259 PatternRewriter &rewriter)
const override {
2260 std::optional<int64_t> index = dimOp.getConstantIndex();
2264 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2265 if (!memrefType || index.value() >= memrefType.getRank() ||
2266 !memrefType.isDynamicDim(index.value()))
2269 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2273 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2274 memrefType.getDynamicDimIndex(index.value()));
2275 rewriter.
replaceOp(dimOp, substituteOp);
2282void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2283 MLIRContext *context) {
2284 results.
add<SimplifyDimOfAllocOp>(context);
2292 Attribute
target, CompilationTarget format,
2293 StringAttr
object, DictionaryAttr properties,
2294 KernelTableAttr kernels) {
2296 return emitError() <<
"the target attribute cannot be null";
2297 if (
target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2299 return emitError() <<
"the target attribute must implement or promise the "
2300 "`gpu::TargetAttrInterface`";
2304ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2305 StringAttr &
object) {
2306 std::optional<CompilationTarget> formatResult;
2307 StringRef enumKeyword;
2310 formatResult = CompilationTarget::Fatbin;
2311 if (!formatResult &&
2313 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2315 return odsParser.
emitError(loc,
"expected an equal sign");
2317 return odsParser.
emitError(loc,
"expected keyword for GPU object format");
2318 FailureOr<StringAttr> objectResult =
2319 FieldParser<StringAttr>::parse(odsParser);
2320 if (
failed(objectResult))
2322 "failed to parse GPU_ObjectAttr parameter "
2323 "'object' which is to be a `StringAttr`");
2324 format = *formatResult;
2325 object = *objectResult;
2329void printObject(AsmPrinter &odsParser, CompilationTarget format,
2330 StringAttr
object) {
2331 if (format != CompilationTarget::Fatbin)
2332 odsParser << stringifyEnum(format) <<
" = ";
2333 odsParser << object;
2346 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(
target)) {
2347 if (intAttr.getInt() < 0) {
2348 return emitError() <<
"the object index must be positive";
2350 }
else if (!
target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2352 <<
"the target attribute must be a GPU Target attribute";
2362LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2363 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2364 return emitOpError() <<
"must be inside an op with symbol table";
2366 MemRefType memrefType = getResultMemref().getType();
2368 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2370 << gpu::AddressSpaceAttr::getMnemonic() <<
"<"
2371 << stringifyEnum(gpu::AddressSpace::Workgroup) <<
">";
2373 if (memrefType.hasStaticShape()) {
2374 return emitOpError() <<
"result memref type must be memref<?xi8, "
2375 "#gpu.address_space<workgroup>>";
2384void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2385 p <<
"(" << getLaneid() <<
")";
2387 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2388 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2389 p <<
"[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() <<
"]";
2391 if (!getArgs().empty())
2392 p <<
" args(" << getArgs() <<
" : " << getArgs().getTypes() <<
")";
2393 if (!getResults().empty())
2394 p <<
" -> (" << getResults().getTypes() <<
')';
2398 !getResults().empty());
2402ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2403 OperationState &
result) {
2405 result.regions.reserve(1);
2406 Region *warpRegion =
result.addRegion();
2409 OpAsmParser::UnresolvedOperand laneId;
2421 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2428 llvm::SMLoc inputsOperandsLoc;
2429 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2430 SmallVector<Type> inputTypes;
2440 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2451 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder,
result.location);
2459void WarpExecuteOnLane0Op::getSuccessorRegions(
2460 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2467 regions.push_back(RegionSuccessor(&getWarpRegion()));
2470ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2473void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2476 build(builder,
result, resultTypes, laneId, warpSize,
2480void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2484 result.addOperands(laneId);
2485 result.addAttribute(getAttributeNames()[0],
2487 result.addTypes(resultTypes);
2488 result.addOperands(args);
2489 assert(args.size() == blockArgTypes.size());
2490 OpBuilder::InsertionGuard guard(builder);
2491 Region *warpRegion =
result.addRegion();
2493 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2502 if (expanded == distributed)
2504 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2505 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2506 if (!expandedVecType || !distributedVecType)
2507 return op->
emitOpError(
"expected vector type for distributed operands.");
2508 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2509 expandedVecType.getElementType() != distributedVecType.getElementType())
2511 "expected distributed vectors to have same rank and element type.");
2514 for (
int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2515 int64_t eDim = expandedVecType.getDimSize(i);
2516 int64_t dDim = distributedVecType.getDimSize(i);
2519 if (eDim % dDim != 0)
2521 <<
"expected expanded vector dimension #" << i <<
" (" << eDim
2522 <<
") to be a multipler of the distributed vector dimension ("
2524 scales[i] = eDim / dDim;
2526 if (llvm::product_of(scales) != warpSize)
2528 <<
"incompatible distribution dimensions from " << expandedVecType
2529 <<
" to " << distributedVecType <<
" with warp size = " << warpSize;
2534LogicalResult WarpExecuteOnLane0Op::verify() {
2535 if (getArgs().size() != getWarpRegion().getNumArguments())
2537 "expected same number op arguments and block arguments.");
2538 auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
2540 return emitOpError(
"expected body to be terminated with 'gpu.yield'");
2541 if (yield.getNumOperands() != getNumResults())
2543 "expected same number of yield operands and return values.");
2544 int64_t warpSize = getWarpSize();
2545 for (
auto [regionArg, arg] :
2546 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2548 warpSize, getOperation())))
2551 for (
auto [yieldOperand,
result] :
2552 llvm::zip_equal(yield.getOperands(), getResults())) {
2554 warpSize, getOperation())))
2559bool WarpExecuteOnLane0Op::areTypesCompatible(Type
lhs, Type
rhs) {
2564gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2565 return cast<gpu::YieldOp>(getBody()->getTerminator());
2572void gpu::SubgroupBroadcastOp::inferResultRanges(
2573 ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
2574 setResultRange(getResult(), argRanges.front());
2578 switch (getBroadcastType()) {
2579 case BroadcastType::first_active_lane:
2583 case BroadcastType::specific_lane:
2587 llvm_unreachable(
"Unknown BroadcastType");
2590LogicalResult gpu::SubgroupBroadcastOp::verify() {
2591 switch (getBroadcastType()) {
2592 case BroadcastType::first_active_lane:
2595 <<
"lane can only be specified for `specific_lane` broadcast";
2597 case BroadcastType::specific_lane:
2600 <<
"lane must be specified for `specific_lane` broadcast";
2603 llvm_unreachable(
"Unknown BroadcastType");
2606OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor ) {
2608 if (
auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2609 return prev.getResult();
2618KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2619 DictionaryAttr metadata) {
2620 assert(kernel &&
"invalid kernel");
2621 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2622 kernel.getAllArgAttrs(), metadata);
2627 FunctionOpInterface kernel,
2628 DictionaryAttr metadata) {
2629 assert(kernel &&
"invalid kernel");
2630 return getChecked(
emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2631 kernel.getAllArgAttrs(), metadata);
2635KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs)
const {
2638 NamedAttrList attrList;
2639 if (DictionaryAttr dict = getMetadata())
2642 return KernelMetadataAttr::get(getName(), getFunctionType(),
getArgAttrs(),
2648 StringAttr name, Type functionType,
2649 ArrayAttr argAttrs, DictionaryAttr metadata) {
2651 return emitError() <<
"the kernel name can't be empty";
2653 if (llvm::any_of(argAttrs, [](Attribute attr) {
2654 return !llvm::isa<DictionaryAttr>(attr);
2657 <<
"all attributes in the array must be a dictionary attribute";
2666KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2667 ArrayRef<KernelMetadataAttr> kernels,
2670 assert((!isSorted || llvm::is_sorted(kernels)) &&
2671 "expected a sorted kernel array");
2673 if (isSorted || llvm::is_sorted(kernels))
2674 return Base::get(context, kernels);
2676 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2677 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2678 return Base::get(context, kernelsTmp);
2681KernelTableAttr KernelTableAttr::getChecked(
2683 ArrayRef<KernelMetadataAttr> kernels,
bool isSorted) {
2685 assert((!isSorted || llvm::is_sorted(kernels)) &&
2686 "expected a sorted kernel array");
2688 if (isSorted || llvm::is_sorted(kernels))
2689 return Base::getChecked(
emitError, context, kernels);
2691 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2692 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2693 return Base::getChecked(
emitError, context, kernelsTmp);
2698 ArrayRef<KernelMetadataAttr> kernels) {
2699 if (kernels.size() < 2)
2702 if (std::adjacent_find(kernels.begin(), kernels.end(),
2703 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2704 return l.getName() == r.getName();
2705 }) != kernels.end()) {
2706 return emitError() <<
"expected all kernels to be uniquely named";
2711KernelMetadataAttr KernelTableAttr::lookup(StringRef key)
const {
2713 return found ? *iterator : KernelMetadataAttr();
2716KernelMetadataAttr KernelTableAttr::lookup(StringAttr key)
const {
2718 return found ? *iterator : KernelMetadataAttr();
2798 return CompilationTarget::Fatbin;
2801std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2803 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
options;
2804 llvm::StringSaver stringSaver(
options.first);
2810 if (!opts.empty() && opts.front() ==
'"' && opts.back() ==
'"')
2811 opts.consume_front(
"\""), opts.consume_back(
"\"");
2812 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2813 opts.consume_front(
"'"), opts.consume_back(
"'");
2815 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2818 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2824std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2829std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2831 size_t startPos =
cmdOptions.find(startsWith);
2832 if (startPos == std::string::npos)
2843#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2844#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2846#define GET_ATTRDEF_CLASSES
2847#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2849#define GET_OP_CLASSES
2850#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2852#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper that checks whether the distributed vector type is consistent with the expanded type and the distributed size.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove a gpu.barrier that immediately follows another gpu.barrier; the threads are already synchronized.
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes for the type.
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext and global objects like types and attributes.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....