36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/ErrorHandling.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/InterleavedRange.h"
42#include "llvm/Support/StringSaver.h"
49#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
55int64_t GPUBlockMappingAttr::getMappingId()
const {
56 return static_cast<int64_t>(getBlock());
59bool GPUBlockMappingAttr::isLinearMapping()
const {
60 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
63int64_t GPUBlockMappingAttr::getRelativeIndex()
const {
64 return isLinearMapping()
65 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
69int64_t GPUWarpgroupMappingAttr::getMappingId()
const {
70 return static_cast<int64_t>(getWarpgroup());
73bool GPUWarpgroupMappingAttr::isLinearMapping()
const {
74 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
77int64_t GPUWarpgroupMappingAttr::getRelativeIndex()
const {
78 return isLinearMapping()
79 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
83int64_t GPUWarpMappingAttr::getMappingId()
const {
84 return static_cast<int64_t>(getWarp());
87bool GPUWarpMappingAttr::isLinearMapping()
const {
88 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
91int64_t GPUWarpMappingAttr::getRelativeIndex()
const {
92 return isLinearMapping()
93 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
97int64_t GPUThreadMappingAttr::getMappingId()
const {
98 return static_cast<int64_t>(getThread());
101bool GPUThreadMappingAttr::isLinearMapping()
const {
102 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
105int64_t GPUThreadMappingAttr::getRelativeIndex()
const {
106 return isLinearMapping()
107 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
111int64_t GPULaneMappingAttr::getMappingId()
const {
112 return static_cast<int64_t>(getLane());
115bool GPULaneMappingAttr::isLinearMapping()
const {
116 return getMappingId() >=
static_cast<int64_t>(MappingId::LinearDim0);
119int64_t GPULaneMappingAttr::getRelativeIndex()
const {
120 return isLinearMapping()
121 ? getMappingId() -
static_cast<int64_t>(MappingId::LinearDim0)
125int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds()
const {
return 64; }
137Value GPUMappingMaskAttr::createLogicalLinearMappingId(
141 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
142 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
143 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
144 filter = arith::SubIOp::create(
b, loc, filter, one);
145 Value filteredId = arith::AndIOp::create(
b, loc, mask, filter);
146 return math::CtPopOp::create(
b, loc, filteredId);
159Value GPUMappingMaskAttr::createIsActiveIdPredicate(
163 arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(getMask()));
164 Value one = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(1));
165 Value filter = arith::ShLIOp::create(
b, loc, one, physicalLinearMappingId);
166 Value filtered = arith::AndIOp::create(
b, loc, mask, filter);
167 Value zero = arith::ConstantOp::create(
b, loc,
b.getI64IntegerAttr(0));
168 return arith::CmpIOp::create(
b, loc, arith::CmpIPredicate::ne, filtered,
172int64_t GPUMemorySpaceMappingAttr::getMappingId()
const {
173 return static_cast<int64_t>(getAddressSpace());
176bool GPUMemorySpaceMappingAttr::isLinearMapping()
const {
177 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support linear mapping");
180int64_t GPUMemorySpaceMappingAttr::getRelativeIndex()
const {
181 llvm_unreachable(
"GPUMemorySpaceMappingAttr does not support relative index");
198 elementType, operand);
212 return elementType.
isF16() || elementType.
isF32() || elementType.
isF64() ||
221 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
222 return emitError() <<
"operand expected to be one of AOp, BOp or COp";
224 if (
shape.size() != 2)
225 return emitError() <<
"MMAMatrixType must have exactly two dimensions";
229 <<
"MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
238bool GPUDialect::isWorkgroupMemoryAddressSpace(
Attribute memorySpace) {
241 if (
auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
242 return gpuAttr.getValue() == getWorkgroupAddressSpace();
246bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
247 Attribute memorySpace = type.getMemorySpace();
248 return isWorkgroupMemoryAddressSpace(memorySpace);
251bool GPUDialect::isKernel(
Operation *op) {
252 UnitAttr isKernelAttr = op->
getAttrOfType<UnitAttr>(getKernelFuncAttrName());
253 return static_cast<bool>(isKernelAttr);
259struct GPUInlinerInterface :
public DialectInlinerInterface {
260 using DialectInlinerInterface::DialectInlinerInterface;
263 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
269void GPUDialect::initialize() {
270 addTypes<AsyncTokenType>();
271 addTypes<MMAMatrixType>();
272 addTypes<SparseDnTensorHandleType>();
273 addTypes<SparseSpMatHandleType>();
274 addTypes<SparseSpGEMMOpHandleType>();
277#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
280#define GET_ATTRDEF_LIST
281#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
283 addInterfaces<GPUInlinerInterface>();
284 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
286 declarePromisedInterfaces<ValueBoundsOpInterface, ClusterDimOp,
287 ClusterDimBlocksOp, ClusterIdOp, ClusterBlockIdOp,
288 BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp,
289 LaneIdOp, SubgroupIdOp, GlobalIdOp, NumSubgroupsOp,
290 SubgroupSizeOp, LaunchOp, SubgroupBroadcastOp>();
296 return "sparse.dntensor_handle";
298 return "sparse.spmat_handle";
300 return "sparse.spgemmop_handle";
302 llvm_unreachable(
"unknown sparse handle kind");
306Type GPUDialect::parseType(DialectAsmParser &parser)
const {
314 if (keyword ==
"async.token")
317 if (keyword ==
"mma_matrix") {
325 SmallVector<int64_t> shape;
346 shape, elementType, operand);
361void GPUDialect::printType(Type type, DialectAsmPrinter &os)
const {
364 .Case<SparseDnTensorHandleType>([&](Type) {
367 .Case<SparseSpMatHandleType>(
369 .Case<SparseSpGEMMOpHandleType>([&](Type) {
375 for (
auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
378 os <<
", \"" << fragTy.
getOperand() <<
"\"" <<
'>';
380 .DefaultUnreachable(
"unexpected 'gpu' type kind");
385 auto array = dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
388 " must be a dense i32 array");
389 if (array.size() != 3)
391 " must contain exactly 3 elements");
395LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
396 NamedAttribute attr) {
397 if (attr.
getName() == getKnownBlockSizeAttrHelper().getName())
399 if (attr.
getName() == getKnownGridSizeAttrHelper().getName())
401 if (attr.
getName() == getKnownClusterSizeAttrHelper().getName())
403 if (!llvm::isa<UnitAttr>(attr.
getValue()) ||
404 attr.
getName() != getContainerModuleAttrName())
407 auto module = dyn_cast<ModuleOp>(op);
410 << getContainerModuleAttrName() <<
"' attribute to be attached to '"
411 << ModuleOp::getOperationName() <<
'\'';
425 return parser.
emitError(loc,
"needs to be named when marked 'async'");
440 if (asyncDependencies.empty())
444 printer << llvm::interleaved_array(asyncDependencies);
472 p <<
' ' << keyword <<
'(';
473 llvm::interleaveComma(
474 llvm::enumerate(values), p, [&p, attributes](
auto pair) {
475 BlockArgument v = pair.value();
476 p << v <<
" : " << v.
getType();
478 size_t attributionIndex = pair.index();
479 DictionaryAttr attrs;
480 if (attributes && attributionIndex < attributes.size())
481 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
491 gpu::AddressSpace memorySpace) {
492 for (
Value v : attributions) {
493 auto type = llvm::dyn_cast<MemRefType>(v.
getType());
495 return op->
emitOpError() <<
"expected memref type in attribution";
500 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
503 if (addressSpace.getValue() != memorySpace)
505 <<
"expected memory space " << stringifyAddressSpace(memorySpace)
506 <<
" in attribution";
517 using Kind = gpu::AllReduceOperation;
518 if (llvm::is_contained(
519 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
521 if (!isa<FloatType>(resType))
525 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
526 Kind::AND, Kind::OR, Kind::XOR},
528 if (!isa<IntegerType>(resType))
535LogicalResult gpu::AllReduceOp::verifyRegions() {
536 if (getBody().empty() != getOp().has_value())
537 return emitError(
"expected either an op attribute or a non-empty body");
538 if (!getBody().empty()) {
539 if (getBody().getNumArguments() != 2)
540 return emitError(
"expected two region arguments");
541 for (
auto argument : getBody().getArguments()) {
542 if (argument.getType() !=
getType())
543 return emitError(
"incorrect region argument type");
545 unsigned yieldCount = 0;
546 for (
Block &block : getBody()) {
547 if (
auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
548 if (yield.getNumOperands() != 1)
549 return emitError(
"expected one gpu.yield operand");
550 if (yield.getOperand(0).getType() !=
getType())
551 return emitError(
"incorrect gpu.yield type");
556 return emitError(
"expected gpu.yield op in region");
558 gpu::AllReduceOperation opName = *getOp();
560 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
561 <<
"` reduction operation is not compatible with type "
570 auto launchOp = dyn_cast<gpu::LaunchOp>(op->
getParentOp());
574 Region &body = launchOp.getBody();
575 assert(!body.
empty() &&
"Invalid region");
581OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor ) {
592 AllReduceOperationAttr &attr) {
595 std::optional<AllReduceOperation> op =
596 gpu::symbolizeAllReduceOperation(enumStr);
599 attr = AllReduceOperationAttr::get(parser.
getContext(), *op);
605 AllReduceOperationAttr attr) {
614LogicalResult gpu::SubgroupReduceOp::verify() {
616 if (
auto vecTy = dyn_cast<VectorType>(elemType)) {
617 if (vecTy.isScalable())
618 return emitOpError() <<
"is not compatible with scalable vector types";
620 elemType = vecTy.getElementType();
623 gpu::AllReduceOperation opName = getOp();
625 return emitError() <<
'`' << gpu::stringifyAllReduceOperation(opName)
626 <<
"` reduction operation is not compatible with type "
630 auto clusterSize = getClusterSize();
632 uint32_t size = *clusterSize;
633 if (!llvm::isPowerOf2_32(size)) {
635 <<
" is not a power of two";
639 uint32_t stride = getClusterStride();
640 if (stride != 1 && !clusterSize) {
641 return emitOpError() <<
"cluster stride can only be specified if cluster "
644 if (!llvm::isPowerOf2_32(stride)) {
645 return emitOpError() <<
"cluster stride " << stride
646 <<
" is not a power of two";
652OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor ) {
653 if (getClusterSize() == 1)
670 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
674 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
692 Value getBlockSizeZ,
Value dynamicSharedMemorySize,
702 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
706 result.addOperands(asyncDependencies);
711 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
712 getBlockSizeY, getBlockSizeZ});
714 result.addOperands(clusterSizeX);
716 result.addOperands(clusterSizeY);
718 result.addOperands(clusterSizeZ);
719 if (dynamicSharedMemorySize)
720 result.addOperands(dynamicSharedMemorySize);
724 result.addAttribute(getModuleAttrName(
result.name), module);
726 result.addAttribute(getFunctionAttrName(
result.name), function);
734 for (
unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
737 for (
Type argTy : workgroupAttributions)
739 for (
Type argTy : privateAttributions)
743 segmentSizes.front() = asyncDependencies.size();
744 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
745 segmentSizes[7] = clusterSizeX ? 1 : 0;
746 segmentSizes[8] = clusterSizeY ? 1 : 0;
747 segmentSizes[9] = clusterSizeZ ? 1 : 0;
748 result.addAttribute(getOperandSegmentSizeAttr(),
753 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
754 auto args = getBody().getArguments();
759 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
760 auto args = getBody().getArguments();
765 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
766 auto args = getBody().getArguments();
771 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
772 auto args = getBody().getArguments();
773 return KernelDim3{args[9], args[10], args[11]};
776std::optional<KernelDim3> LaunchOp::getClusterIds() {
777 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
778 if (!hasClusterSize())
780 auto args = getBody().getArguments();
781 return KernelDim3{args[12], args[13], args[14]};
784std::optional<KernelDim3> LaunchOp::getClusterSize() {
785 assert(!getBody().empty() &&
"LaunchOp body must not be empty.");
786 if (!hasClusterSize())
788 auto args = getBody().getArguments();
789 return KernelDim3{args[15], args[16], args[17]};
792KernelDim3 LaunchOp::getGridSizeOperandValues() {
793 auto operands = getOperands().drop_front(getAsyncDependencies().size());
794 return KernelDim3{operands[0], operands[1], operands[2]};
797KernelDim3 LaunchOp::getBlockSizeOperandValues() {
798 auto operands = getOperands().drop_front(getAsyncDependencies().size());
799 return KernelDim3{operands[3], operands[4], operands[5]};
802std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
803 auto operands = getOperands().drop_front(getAsyncDependencies().size());
804 if (!hasClusterSize())
806 return KernelDim3{operands[6], operands[7], operands[8]};
809LogicalResult LaunchOp::verify() {
810 if (!(hasClusterSize()) &&
811 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
812 return emitOpError() <<
"cluster size must be all present";
816LogicalResult LaunchOp::verifyRegions() {
820 if (getBody().empty()) {
823 if (getBody().getNumArguments() <
824 kNumConfigRegionAttributes + getNumWorkgroupAttributions()) {
825 return emitOpError(
"unexpected number of region arguments");
830 GPUDialect::getWorkgroupAddressSpace())) ||
832 GPUDialect::getPrivateAddressSpace())))
837 for (
Block &block : getBody()) {
840 if (block.back().getNumSuccessors() != 0)
842 if (!isa<gpu::TerminatorOp>(&block.back())) {
845 .append(
"expected '", gpu::TerminatorOp::getOperationName(),
846 "' or a terminator with successors")
847 .attachNote(getLoc())
848 .append(
"in '", LaunchOp::getOperationName(),
"' body region");
852 if (getNumResults() == 0 && getAsyncToken())
853 return emitOpError(
"needs to be named when async keyword is specified");
864 p <<
'(' << ids.
x <<
", " << ids.
y <<
", " << ids.
z <<
") in (";
865 p << size.
x <<
" = " << operands.
x <<
", ";
866 p << size.
y <<
" = " << operands.
y <<
", ";
867 p << size.
z <<
" = " << operands.
z <<
')';
870void LaunchOp::print(OpAsmPrinter &p) {
871 if (getAsyncToken()) {
873 if (!getAsyncDependencies().empty())
874 p <<
" [" << getAsyncDependencies() <<
']';
877 if (hasClusterSize()) {
878 p <<
' ' << getClustersKeyword();
880 getClusterSizeOperandValues().value(),
881 getClusterIds().value());
883 p <<
' ' << getBlocksKeyword();
886 p <<
' ' << getThreadsKeyword();
889 if (getDynamicSharedMemorySize())
890 p <<
' ' << getDynamicSharedMemorySizeKeyword() <<
' '
891 << getDynamicSharedMemorySize();
894 StringRef moduleAttrName = getModuleAttrName();
895 if (
auto module = getModule()) {
896 p <<
' ' << moduleAttrName <<
'(';
901 StringRef functionAttrName = getFunctionAttrName();
902 if (
auto function = getFunction()) {
903 p <<
' ' << functionAttrName <<
'(';
915 LaunchOp::getOperandSegmentSizeAttr(),
916 getNumWorkgroupAttributionsAttrName(),
917 moduleAttrName, functionAttrName});
932 assert(
indices.size() == 3 &&
"space for three indices expected");
939 if (args.size() != 3) {
941 << keyword <<
" expects 3 arguments, but got " << args.size();
943 std::move(args.begin(), args.end(),
indices.begin());
945 for (
int i = 0; i < 3; ++i) {
967ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &
result) {
969 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
970 sizes(LaunchOp::kNumConfigOperands);
973 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
974 LaunchOp::kNumConfigRegionAttributes);
977 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
988 "gpu.launch requires 'async' keyword to return a value");
989 result.types.push_back(asyncTokenType);
992 bool hasCluster =
false;
996 regionArgs.resize(18);
998 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
999 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
1005 parser, sizesRef.drop_front(6), regionArgsRef.slice(15, 3),
1006 regionArgsRef.slice(12, 3), LaunchOp::getClustersKeyword()))
1014 if (parser.
parseKeyword(LaunchOp::getBlocksKeyword()) ||
1016 regionArgsRef.slice(6, 3), regionArgsRef.slice(0, 3),
1017 LaunchOp::getBlocksKeyword()) ||
1020 regionArgsRef.slice(9, 3), regionArgsRef.slice(3, 3),
1021 LaunchOp::getThreadsKeyword()) ||
1026 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
1027 bool hasDynamicSharedMemorySize =
false;
1029 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
1030 hasDynamicSharedMemorySize =
true;
1039 StringRef moduleAttrName = getModuleAttrName(
result.name);
1041 FlatSymbolRefAttr moduleSymbol;
1049 StringRef functionAttrName = getFunctionAttrName(
result.name);
1051 FlatSymbolRefAttr funcSymbol;
1066 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1067 LaunchOp::kNumConfigRegionAttributes + 6, index);
1069 SmallVector<OpAsmParser::Argument> regionArguments;
1070 for (
auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1071 OpAsmParser::Argument arg;
1072 arg.
ssaName = std::get<0>(ssaValueAndType);
1073 arg.
type = std::get<1>(ssaValueAndType);
1074 regionArguments.push_back(arg);
1085 unsigned numWorkgroupAttrs = regionArguments.size() -
1086 LaunchOp::kNumConfigRegionAttributes -
1087 (hasCluster ? 6 : 0);
1088 result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1099 Region *body =
result.addRegion();
1104 SmallVector<int32_t, 11> segmentSizes(11, 1);
1105 segmentSizes.front() = asyncDependencies.size();
1108 segmentSizes[7] = 0;
1109 segmentSizes[8] = 0;
1110 segmentSizes[9] = 0;
1112 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1113 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1127 bool simplified =
false;
1128 auto constPropIdUses = [&](
Value id,
Value size) {
1132 if (
id.getUses().empty())
1144 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1145 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1146 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1147 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1148 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1149 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1155void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1156 MLIRContext *context) {
1157 rewrites.
add<FoldLaunchArguments>(context);
1162BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1163 auto attrName = getNumWorkgroupAttributionsAttrName();
1164 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1165 (*this)->setAttr(attrName,
1166 IntegerAttr::get(attr.getType(), attr.
getValue() + 1));
1167 return getBody().insertArgument(
1168 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1173BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1176 return getBody().addArgument(type, loc);
1183void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1184 SymbolRefAttr kernelSymbol,
KernelDim3 gridSize,
1185 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1186 ValueRange kernelOperands, Type asyncTokenType,
1188 std::optional<KernelDim3> clusterSize) {
1189 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1190 "expected a symbol reference with a single nested reference");
1191 result.addOperands(asyncDependencies);
1198 if (clusterSize.has_value())
1199 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1200 if (dynamicSharedMemorySize)
1201 result.addOperands(dynamicSharedMemorySize);
1202 result.addOperands(kernelOperands);
1204 Properties &prop =
result.getOrAddProperties<Properties>();
1205 prop.kernel = kernelSymbol;
1206 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1208 llvm::fill(prop.operandSegmentSizes, 1);
1209 prop.operandSegmentSizes[0] = asyncDependencies.size();
1210 if (!clusterSize.has_value()) {
1211 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1212 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1213 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1215 prop.operandSegmentSizes[segmentSizesLen - 3] =
1216 dynamicSharedMemorySize ? 1 : 0;
1217 prop.operandSegmentSizes[segmentSizesLen - 2] =
1218 static_cast<int32_t
>(kernelOperands.size());
1219 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1222void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1224 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1225 ValueRange kernelOperands, Type asyncTokenType,
1227 std::optional<KernelDim3> clusterSize) {
1228 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1230 SymbolRefAttr::get(kernelModule.getNameAttr(),
1231 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1232 build(builder,
result, kernelSymbol, gridSize, getBlockSize,
1233 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1234 asyncDependencies, clusterSize);
1237void LaunchFuncOp::build(OpBuilder &builder, OperationState &
result,
1239 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1240 ValueRange kernelOperands, Value asyncObject,
1241 std::optional<KernelDim3> clusterSize) {
1245 if (clusterSize.has_value())
1246 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1247 if (dynamicSharedMemorySize)
1248 result.addOperands(dynamicSharedMemorySize);
1249 result.addOperands(kernelOperands);
1251 result.addOperands(asyncObject);
1252 Properties &prop =
result.getOrAddProperties<Properties>();
1253 prop.kernel = kernel;
1254 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1256 llvm::fill(prop.operandSegmentSizes, 1);
1257 prop.operandSegmentSizes[0] = 0;
1258 if (!clusterSize.has_value()) {
1259 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1260 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1261 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1263 prop.operandSegmentSizes[segmentSizesLen - 3] =
1264 dynamicSharedMemorySize ? 1 : 0;
1265 prop.operandSegmentSizes[segmentSizesLen - 2] =
1266 static_cast<int32_t
>(kernelOperands.size());
1267 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1270StringAttr LaunchFuncOp::getKernelModuleName() {
1274StringAttr LaunchFuncOp::getKernelName() {
1278unsigned LaunchFuncOp::getNumKernelOperands() {
1279 return getKernelOperands().size();
1282Value LaunchFuncOp::getKernelOperand(
unsigned i) {
1283 return getKernelOperands()[i];
1286KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1287 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1288 return KernelDim3{operands[0], operands[1], operands[2]};
1291KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1292 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1293 return KernelDim3{operands[3], operands[4], operands[5]};
1296KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1297 assert(hasClusterSize() &&
1298 "cluster size is not set, check hasClusterSize() first");
1299 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1300 return KernelDim3{operands[6], operands[7], operands[8]};
1303LogicalResult LaunchFuncOp::verify() {
1304 auto module = (*this)->getParentOfType<ModuleOp>();
1306 return emitOpError(
"expected to belong to a module");
1308 if (!module->getAttrOfType<UnitAttr>(
1309 GPUDialect::getContainerModuleAttrName()))
1310 return emitOpError(
"expected the closest surrounding module to have the '" +
1311 GPUDialect::getContainerModuleAttrName() +
1314 if (hasClusterSize()) {
1315 if (getClusterSizeY().
getType() != getClusterSizeX().
getType() ||
1318 <<
"expects types of the cluster dimensions must be the same";
1325LaunchFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1326 LaunchFuncOp launchOp = *
this;
1329 if (isa<GPUModuleOp>(table))
1334 if (!launchOp->getParentOp() ||
1335 launchOp->getParentOp()->getParentOp() != table)
1340 if (!launchOp->getAttrOfType<SymbolRefAttr>(
1341 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
1345 StringAttr kernelContainerName = launchOp.getKernelModuleName();
1346 Operation *kernelContainer =
1348 if (!kernelContainer)
1350 <<
"kernel container '" << kernelContainerName.getValue()
1351 <<
"' is undefined";
1354 if (isa<BinaryOp>(kernelContainer))
1357 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
1359 return launchOp.emitOpError()
1360 <<
"kernel module '" << kernelContainerName.getValue()
1361 <<
"' is undefined";
1365 kernelModule, launchOp.getKernelName());
1368 << launchOp.getKernel() <<
"' is undefined";
1369 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
1370 if (!kernelConvertedFunction) {
1371 InFlightDiagnostic
diag = launchOp.emitOpError()
1372 <<
"referenced kernel '" << launchOp.getKernel()
1373 <<
"' is not a function";
1374 diag.attachNote(kernelFunc->
getLoc()) <<
"see the kernel definition here";
1379 GPUDialect::getKernelFuncAttrName()))
1380 return launchOp.emitOpError(
"kernel function is missing the '")
1381 << GPUDialect::getKernelFuncAttrName() <<
"' attribute";
1386 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
1387 if (!kernelGPUFunction)
1390 unsigned actualNumArguments = launchOp.getNumKernelOperands();
1391 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
1392 if (expectedNumArguments != actualNumArguments)
1393 return launchOp.emitOpError(
"got ")
1394 << actualNumArguments <<
" kernel operands but expected "
1395 << expectedNumArguments;
1397 FunctionType functionType = kernelGPUFunction.getFunctionType();
1398 for (
unsigned i = 0; i < expectedNumArguments; ++i) {
1399 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
1400 return launchOp.emitOpError(
"type of function argument ")
1401 << i <<
" does not match";
1410 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1411 Type &clusterXTy,
Type &clusterYTy,
Type &clusterZTy) {
1418 if (clusterValue.has_value()) {
1419 clusterXTy = clusterYTy = clusterZTy = dimTy;
1426 Type clusterYTy,
Type clusterZTy) {
1428 printer <<
": " << dimTy;
1438 auto parseElement = [&]() -> ParseResult {
1439 return failure(parser.
parseOperand(argNames.emplace_back()) ||
1444 parseElement,
" in argument list");
1449 if (operands.empty())
1452 llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
1453 [&](
const auto &pair) {
1454 auto [operand, type] = pair;
1455 printer << operand <<
" : " << type;
1464void ShuffleOp::build(OpBuilder &builder, OperationState &
result, Value value,
1465 int32_t offset, int32_t width, ShuffleMode mode) {
1466 build(builder,
result, value,
1467 arith::ConstantOp::create(builder,
result.location,
1469 arith::ConstantOp::create(builder,
result.location,
1478LogicalResult RotateOp::verify() {
1479 uint32_t offset = getOffset();
1480 uint32_t width = getWidth();
1482 if (offset >= width) {
1483 return emitOpError() <<
"offset must be in the range [0, " << width <<
")";
1496 auto nextOp = dyn_cast_or_null<BarrierOp>(op->getNextNode());
1500 std::optional<ArrayAttr> thisMemfence = op.getAddressSpaces();
1501 std::optional<ArrayAttr> nextMemfence = nextOp.getAddressSpaces();
1505 if (!nextMemfence) {
1506 op.removeAddressSpacesAttr();
1510 if (*thisMemfence == *nextMemfence) {
1514 llvm::SmallSetVector<Attribute, 4> mergedSpaces;
1516 mergedSpaces.insert(attr);
1518 mergedSpaces.insert(attr);
1519 op.setAddressSpacesAttr(rewriter.
getArrayAttr(mergedSpaces.takeVector()));
1527void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1528 MLIRContext *context) {
1532void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1533 mlir::OperationState &odsState,
1534 std::optional<AddressSpace> addressSpace) {
1538 AddressSpaceAttr::get(odsBuilder.
getContext(), addressSpace.value()));
1539 build(odsBuilder, odsState, addressSpacesAttr);
1546void BarrierOp::build(OpBuilder &builder, OperationState &odsState,
1547 Value memrefToFence) {
1548 std::optional<AddressSpace> addrSpaceToFence;
1549 if (
auto memrefType = dyn_cast<BaseMemRefType>(memrefToFence.
getType()))
1550 if (
auto addrSpaceAttr = dyn_cast_if_present<gpu::AddressSpaceAttr>(
1551 memrefType.getMemorySpace()))
1552 addrSpaceToFence = addrSpaceAttr.getValue();
1553 return build(builder, odsState, addrSpaceToFence);
1562BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1563 auto attrName = getNumWorkgroupAttributionsAttrName();
1564 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1565 (*this)->setAttr(attrName,
1566 IntegerAttr::get(attr.getType(), attr.
getValue() + 1));
1567 return getBody().insertArgument(
1568 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1573BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1576 return getBody().addArgument(type, loc);
1579void GPUFuncOp::build(OpBuilder &builder, OperationState &
result,
1580 StringRef name, FunctionType type,
1583 ArrayRef<NamedAttribute> attrs) {
1584 OpBuilder::InsertionGuard g(builder);
1588 result.addAttribute(getFunctionTypeAttrName(
result.name),
1589 TypeAttr::get(type));
1590 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1592 result.addAttributes(attrs);
1593 Region *body =
result.addRegion();
1597 for (Type argTy : type.getInputs())
1599 for (Type argTy : workgroupAttributions)
1601 for (Type argTy : privateAttributions)
1620 size_t existingArgs = args.size();
1627 bool hadAttrs = llvm::any_of(
ArrayRef(args).drop_front(existingArgs),
1632 attributionAttrs =
nullptr;
1638 for (
const auto &argument :
ArrayRef(args).drop_front(existingArgs)) {
1639 if (!argument.attrs)
1642 attributionAttrsVec.push_back(argument.attrs);
1644 attributionAttrs = builder.
getArrayAttr(attributionAttrsVec);
1653ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &
result) {
1654 SmallVector<OpAsmParser::Argument> entryArgs;
1655 SmallVector<DictionaryAttr> resultAttrs;
1656 SmallVector<Type> resultTypes;
1660 StringAttr nameAttr;
1667 parser,
false, entryArgs, isVariadic, resultTypes,
1671 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1672 return parser.
emitError(signatureLocation)
1673 <<
"gpu.func requires named arguments";
1679 SmallVector<Type> argTypes;
1680 for (
auto &arg : entryArgs)
1681 argTypes.push_back(arg.
type);
1683 result.addAttribute(getFunctionTypeAttrName(
result.name),
1684 TypeAttr::get(type));
1687 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
1688 getResAttrsAttrName(
result.name));
1690 Attribute workgroupAttributionAttrs;
1693 entryArgs, workgroupAttributionAttrs)))
1698 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1699 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1701 if (workgroupAttributionAttrs)
1702 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(
result.name),
1703 workgroupAttributionAttrs);
1705 Attribute privateAttributionAttrs;
1708 entryArgs, privateAttributionAttrs)))
1710 if (privateAttributionAttrs)
1711 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(
result.name),
1712 privateAttributionAttrs);
1716 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1725 auto *body =
result.addRegion();
1729void GPUFuncOp::print(OpAsmPrinter &p) {
1733 FunctionType type = getFunctionType();
1739 getWorkgroupAttribAttrs().value_or(
nullptr));
1741 getPrivateAttribAttrs().value_or(
nullptr));
1743 p <<
' ' << getKernelKeyword();
1747 {getNumWorkgroupAttributionsAttrName(),
1748 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1749 getArgAttrsAttrName(), getResAttrsAttrName(),
1750 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1756 StringAttr attrName) {
1757 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1758 if (!allAttrs ||
index >= allAttrs.size())
1759 return DictionaryAttr();
1760 return llvm::cast<DictionaryAttr>(allAttrs[
index]);
1763DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(
unsigned index) {
1767DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(
unsigned index) {
1772 DictionaryAttr value, StringAttr attrName) {
1774 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1777 elements.append(allAttrs.begin(), allAttrs.end());
1778 while (elements.size() <=
index)
1779 elements.push_back(DictionaryAttr::get(ctx));
1781 elements[
index] = DictionaryAttr::get(ctx);
1783 elements[
index] = value;
1784 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1785 op->setAttr(attrName, newValue);
1788void GPUFuncOp::setworkgroupAttributionAttrs(
unsigned index,
1789 DictionaryAttr value) {
1793void GPUFuncOp::setPrivateAttributionAttrs(
unsigned int index,
1794 DictionaryAttr value) {
1799 StringAttr name, StringAttr attrsName) {
1803 return dict.get(name);
1806Attribute GPUFuncOp::getWorkgroupAttributionAttr(
unsigned index,
1808 assert(index < getNumWorkgroupAttributions() &&
1809 "index must map to a workgroup attribution");
1811 getWorkgroupAttribAttrsAttrName());
1814Attribute GPUFuncOp::getPrivateAttributionAttr(
unsigned index,
1816 assert(index < getNumPrivateAttributions() &&
1817 "index must map to a private attribution");
1819 getPrivateAttribAttrsAttrName());
1823 Attribute value, StringAttr attrsName) {
1828 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1831 bool mustSort =
true;
1832 for (
unsigned i = 0, e = elems.size(); i < e; ++i) {
1833 if (elems[i].getName() == name) {
1836 std::swap(elems[i], elems[elems.size() - 1]);
1848 elems.emplace_back(name, value);
1851 DictionaryAttr::sortInPlace(elems);
1853 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1857void GPUFuncOp::setWorkgroupAttributionAttr(
unsigned index, StringAttr name,
1859 assert(index < getNumWorkgroupAttributions() &&
1860 "index must map to a workgroup attribution");
1862 getWorkgroupAttribAttrsAttrName());
1865void GPUFuncOp::setPrivateAttributionAttr(
unsigned index, StringAttr name,
1867 assert(index < getNumPrivateAttributions() &&
1868 "index must map to a private attribution");
1870 getPrivateAttribAttrsAttrName());
1873LogicalResult GPUFuncOp::verifyType() {
1874 if (isKernel() && getFunctionType().getNumResults() != 0)
1875 return emitOpError() <<
"expected void return type for kernel function";
1881LogicalResult GPUFuncOp::verifyBody() {
1883 return emitOpError() <<
"expected body with at least one block";
1884 unsigned numFuncArguments = getNumArguments();
1885 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1886 unsigned numBlockArguments = front().getNumArguments();
1887 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1889 << numFuncArguments + numWorkgroupAttributions
1890 <<
" arguments to body region";
1892 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1893 for (
unsigned i = 0; i < numFuncArguments; ++i) {
1894 Type blockArgType = front().getArgument(i).getType();
1895 if (funcArgTypes[i] != blockArgType)
1896 return emitOpError() <<
"expected body region argument #" << i
1897 <<
" to be of type " << funcArgTypes[i] <<
", got "
1902 GPUDialect::getWorkgroupAddressSpace())) ||
1904 GPUDialect::getPrivateAddressSpace())))
1914LogicalResult gpu::ReturnOp::verify() {
1915 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1917 FunctionType funType = function.getFunctionType();
1919 if (funType.getNumResults() != getOperands().size())
1921 .append(
"expected ", funType.getNumResults(),
" result operands")
1922 .attachNote(function.getLoc())
1923 .append(
"return type declared here");
1925 for (
const auto &pair : llvm::enumerate(
1926 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1927 auto [type, operand] = pair.value();
1928 if (type != operand.getType())
1929 return emitOpError() <<
"unexpected type `" << operand.getType()
1930 <<
"' for operand #" << pair.index();
1939void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1941 Attribute offloadingHandler) {
1942 result.addRegion()->emplaceBlock();
1943 Properties &props =
result.getOrAddProperties<Properties>();
1945 props.targets = targets;
1947 props.offloadingHandler = offloadingHandler;
1950void GPUModuleOp::build(OpBuilder &builder, OperationState &
result,
1951 StringRef name, ArrayRef<Attribute> targets,
1952 Attribute offloadingHandler) {
1953 build(builder,
result, name,
1958bool GPUModuleOp::hasTarget(Attribute
target) {
1959 if (
ArrayAttr targets = getTargetsAttr())
1960 return llvm::count(targets.getValue(),
target);
1964void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1965 ArrayAttr &targetsAttr = getProperties().targets;
1966 SmallVector<Attribute> targetsVector(targets);
1967 targetsAttr = ArrayAttr::get(
getContext(), targetsVector);
1970LogicalResult GPUModuleOp::verify() {
1971 auto targets = getOperation()->getAttrOfType<
ArrayAttr>(
"targets");
1976 for (
auto target : targets) {
1977 if (
auto verifyTargetAttr =
1978 llvm::dyn_cast<TargetAttrVerifyInterface>(
target)) {
1979 if (verifyTargetAttr.verifyTarget(getOperation()).
failed())
1989void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
1990 Attribute offloadingHandler,
ArrayAttr objects) {
1991 auto &properties =
result.getOrAddProperties<Properties>();
1994 properties.objects = objects;
1995 if (offloadingHandler)
1996 properties.offloadingHandler = offloadingHandler;
1998 properties.offloadingHandler = builder.
getAttr<SelectObjectAttr>(
nullptr);
2001void BinaryOp::build(OpBuilder &builder, OperationState &
result, StringRef name,
2002 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
2003 build(builder,
result, name, offloadingHandler,
2015 if (!offloadingHandler)
2022 if (offloadingHandler != SelectObjectAttr::get(op->
getContext(),
nullptr))
2023 printer <<
'<' << offloadingHandler <<
'>';
2030LogicalResult MemcpyOp::verify() {
2031 auto srcType = getSrc().getType();
2032 auto dstType = getDst().getType();
2035 return emitOpError(
"arguments have incompatible element type");
2038 return emitOpError(
"arguments have incompatible shape");
2047struct EraseTrivialCopyOp :
public OpRewritePattern<MemcpyOp> {
2048 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
2050 LogicalResult matchAndRewrite(MemcpyOp op,
2051 PatternRewriter &rewriter)
const override {
2052 Value dest = op.getDst();
2061 if (llvm::any_of(dest.
getUsers(), [op, dest](Operation *user) {
2062 return user != op &&
2063 !hasSingleEffect<MemoryEffects::Free>(user, dest);
2069 if (op.getAsyncDependencies().size() > 1 ||
2070 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2071 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2073 rewriter.
replaceOp(op, op.getAsyncDependencies());
2080void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2081 MLIRContext *context) {
2082 results.
add<EraseTrivialCopyOp>(context);
2089LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2090 auto srcType = getSrcMemref().getType();
2091 auto resType = getRes().getType();
2092 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2093 auto operand = resMatrixType.getOperand();
2094 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2096 if (!srcMemrefType.isLastDimUnitStride())
2098 "expected source memref most minor dim must have unit stride");
2100 if (operand !=
"AOp" && operand !=
"BOp" && operand !=
"COp")
2101 return emitError(
"only AOp, BOp and COp can be loaded");
2110LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2111 auto srcType = getSrc().getType();
2112 auto dstType = getDstMemref().getType();
2113 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2114 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2116 if (!dstMemrefType.isLastDimUnitStride())
2118 "expected destination memref most minor dim must have unit stride");
2120 if (srcMatrixType.getOperand() !=
"COp")
2122 "expected the operand matrix being stored to have 'COp' operand type");
2131LogicalResult SubgroupMmaComputeOp::verify() {
2132 enum OperandMap {
A,
B,
C };
2133 SmallVector<MMAMatrixType, 3> opTypes;
2134 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().
getType()));
2135 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().
getType()));
2136 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().
getType()));
2138 if (opTypes[A].getOperand() !=
"AOp" || opTypes[B].getOperand() !=
"BOp" ||
2139 opTypes[C].getOperand() !=
"COp")
2140 return emitError(
"operands must be in the order AOp, BOp, COp");
2142 ArrayRef<int64_t> aShape, bShape, cShape;
2143 aShape = opTypes[
A].getShape();
2144 bShape = opTypes[
B].getShape();
2145 cShape = opTypes[
C].getShape();
2147 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2148 bShape[1] != cShape[1])
2149 return emitError(
"operand shapes do not satisfy matmul constraints");
2154LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2155 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2159LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2160 SmallVectorImpl<::mlir::OpFoldResult> &results) {
2173struct EraseRedundantGpuWaitOpPairs :
public OpRewritePattern<WaitOp> {
2177 LogicalResult matchAndRewrite(WaitOp op,
2178 PatternRewriter &rewriter)
const final {
2179 auto predicate = [](Value value) {
2180 auto waitOp = value.getDefiningOp<WaitOp>();
2181 return waitOp && waitOp->getNumOperands() == 0;
2183 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2185 SmallVector<Value> validOperands;
2186 for (Value operand : op->getOperands()) {
2187 if (predicate(operand))
2189 validOperands.push_back(operand);
2191 rewriter.
modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2203struct SimplifyGpuWaitOp :
public OpRewritePattern<WaitOp> {
2207 LogicalResult matchAndRewrite(WaitOp op,
2208 PatternRewriter &rewriter)
const final {
2211 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2216 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2217 op.getAsyncToken()) {
2218 rewriter.
replaceOp(op, op.getAsyncDependencies());
2222 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2232void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2233 MLIRContext *context) {
2234 results.
add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2241LogicalResult AllocOp::verify() {
2242 auto memRefType = llvm::cast<MemRefType>(getMemref().
getType());
2248 unsigned numSymbols = 0;
2249 if (!memRefType.getLayout().isIdentity())
2250 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2251 if (getSymbolOperands().size() != numSymbols) {
2253 "symbol operand count does not equal memref symbol count");
2263struct SimplifyDimOfAllocOp :
public OpRewritePattern<memref::DimOp> {
2264 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2266 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2267 PatternRewriter &rewriter)
const override {
2268 std::optional<int64_t> index = dimOp.getConstantIndex();
2272 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2273 if (!memrefType || index.value() >= memrefType.getRank() ||
2274 !memrefType.isDynamicDim(index.value()))
2277 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2281 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2282 memrefType.getDynamicDimIndex(index.value()));
2283 rewriter.
replaceOp(dimOp, substituteOp);
2290void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2291 MLIRContext *context) {
2292 results.
add<SimplifyDimOfAllocOp>(context);
2300 Attribute
target, CompilationTarget format,
2301 StringAttr
object, DictionaryAttr properties,
2302 KernelTableAttr kernels) {
2304 return emitError() <<
"the target attribute cannot be null";
2305 if (
target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2307 return emitError() <<
"the target attribute must implement or promise the "
2308 "`gpu::TargetAttrInterface`";
2312ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2313 StringAttr &
object) {
2314 std::optional<CompilationTarget> formatResult;
2315 StringRef enumKeyword;
2318 formatResult = CompilationTarget::Fatbin;
2319 if (!formatResult &&
2321 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2323 return odsParser.
emitError(loc,
"expected an equal sign");
2325 return odsParser.
emitError(loc,
"expected keyword for GPU object format");
2326 FailureOr<StringAttr> objectResult =
2327 FieldParser<StringAttr>::parse(odsParser);
2328 if (
failed(objectResult))
2330 "failed to parse GPU_ObjectAttr parameter "
2331 "'object' which is to be a `StringAttr`");
2332 format = *formatResult;
2333 object = *objectResult;
2337void printObject(AsmPrinter &odsParser, CompilationTarget format,
2338 StringAttr
object) {
2339 if (format != CompilationTarget::Fatbin)
2340 odsParser << stringifyEnum(format) <<
" = ";
2341 odsParser << object;
2354 if (
auto intAttr = mlir::dyn_cast<IntegerAttr>(
target)) {
2355 if (intAttr.getInt() < 0) {
2356 return emitError() <<
"the object index must be positive";
2358 }
else if (!
target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2360 <<
"the target attribute must be a GPU Target attribute";
2370LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2371 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2372 return emitOpError() <<
"must be inside an op with symbol table";
2374 MemRefType memrefType = getResultMemref().getType();
2376 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2378 << gpu::AddressSpaceAttr::getMnemonic() <<
"<"
2379 << stringifyEnum(gpu::AddressSpace::Workgroup) <<
">";
2381 if (memrefType.hasStaticShape()) {
2382 return emitOpError() <<
"result memref type must be memref<?xi8, "
2383 "#gpu.address_space<workgroup>>";
2392void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2393 p <<
"(" << getLaneid() <<
")";
2395 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2396 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2397 p <<
"[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() <<
"]";
2399 if (!getArgs().empty())
2400 p <<
" args(" << getArgs() <<
" : " << getArgs().getTypes() <<
")";
2401 if (!getResults().empty())
2402 p <<
" -> (" << getResults().getTypes() <<
')';
2406 !getResults().empty());
2410ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2411 OperationState &
result) {
2413 result.regions.reserve(1);
2414 Region *warpRegion =
result.addRegion();
2417 OpAsmParser::UnresolvedOperand laneId;
2429 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2436 llvm::SMLoc inputsOperandsLoc;
2437 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2438 SmallVector<Type> inputTypes;
2448 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2459 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder,
result.location);
2467void WarpExecuteOnLane0Op::getSuccessorRegions(
2468 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
2475 regions.push_back(RegionSuccessor(&getWarpRegion()));
2478ValueRange WarpExecuteOnLane0Op::getSuccessorInputs(RegionSuccessor successor) {
2481void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2484 build(builder,
result, resultTypes, laneId, warpSize,
2488void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &
result,
2492 result.addOperands(laneId);
2493 result.addAttribute(getAttributeNames()[0],
2495 result.addTypes(resultTypes);
2496 result.addOperands(args);
2497 assert(args.size() == blockArgTypes.size());
2498 OpBuilder::InsertionGuard guard(builder);
2499 Region *warpRegion =
result.addRegion();
2501 for (
auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2510 if (expanded == distributed)
2512 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2513 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2514 if (!expandedVecType || !distributedVecType)
2515 return op->
emitOpError(
"expected vector type for distributed operands.");
2516 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2517 expandedVecType.getElementType() != distributedVecType.getElementType())
2519 "expected distributed vectors to have same rank and element type.");
2522 for (
int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2523 int64_t eDim = expandedVecType.getDimSize(i);
2524 int64_t dDim = distributedVecType.getDimSize(i);
2527 if (eDim % dDim != 0)
2529 <<
"expected expanded vector dimension #" << i <<
" (" << eDim
2530 <<
") to be a multipler of the distributed vector dimension ("
2532 scales[i] = eDim / dDim;
2534 if (llvm::product_of(scales) != warpSize)
2536 <<
"incompatible distribution dimensions from " << expandedVecType
2537 <<
" to " << distributedVecType <<
" with warp size = " << warpSize;
2542LogicalResult WarpExecuteOnLane0Op::verify() {
2543 if (getArgs().size() != getWarpRegion().getNumArguments())
2545 "expected same number op arguments and block arguments.");
2546 auto yield = dyn_cast<gpu::YieldOp>(getBody()->getTerminator());
2548 return emitOpError(
"expected body to be terminated with 'gpu.yield'");
2549 if (yield.getNumOperands() != getNumResults())
2551 "expected same number of yield operands and return values.");
2552 int64_t warpSize = getWarpSize();
2553 for (
auto [regionArg, arg] :
2554 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2556 warpSize, getOperation())))
2559 for (
auto [yieldOperand,
result] :
2560 llvm::zip_equal(yield.getOperands(), getResults())) {
2562 warpSize, getOperation())))
2567bool WarpExecuteOnLane0Op::areTypesCompatible(Type
lhs, Type
rhs) {
2572gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2573 return cast<gpu::YieldOp>(getBody()->getTerminator());
2580void gpu::SubgroupBroadcastOp::inferResultRanges(
2581 ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
2582 setResultRange(getResult(), argRanges.front());
2586 switch (getBroadcastType()) {
2587 case BroadcastType::first_active_lane:
2591 case BroadcastType::specific_lane:
2595 llvm_unreachable(
"Unknown BroadcastType");
2598LogicalResult gpu::SubgroupBroadcastOp::verify() {
2599 switch (getBroadcastType()) {
2600 case BroadcastType::first_active_lane:
2603 <<
"lane can only be specified for `specific_lane` broadcast";
2605 case BroadcastType::specific_lane:
2608 <<
"lane must be specified for `specific_lane` broadcast";
2611 llvm_unreachable(
"Unknown BroadcastType");
2614OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor ) {
2616 if (
auto prev = getSrc().getDefiningOp<SubgroupBroadcastOp>())
2617 return prev.getResult();
2632KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2633 DictionaryAttr metadata) {
2634 assert(kernel &&
"invalid kernel");
2635 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2636 kernel.getAllArgAttrs(), metadata);
2641 FunctionOpInterface kernel,
2642 DictionaryAttr metadata) {
2643 assert(kernel &&
"invalid kernel");
2645 kernel.getAllArgAttrs(), metadata);
2649KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs)
const {
2652 NamedAttrList attrList;
2653 if (DictionaryAttr dict = getMetadata())
2656 return KernelMetadataAttr::get(getName(), getFunctionType(),
getArgAttrs(),
2662 StringAttr name, Type functionType,
2663 ArrayAttr argAttrs, DictionaryAttr metadata) {
2665 return emitError() <<
"the kernel name can't be empty";
2667 if (llvm::any_of(argAttrs, [](Attribute attr) {
2668 return !llvm::isa<DictionaryAttr>(attr);
2671 <<
"all attributes in the array must be a dictionary attribute";
2680KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2681 ArrayRef<KernelMetadataAttr> kernels,
2684 assert((!isSorted || llvm::is_sorted(kernels)) &&
2685 "expected a sorted kernel array");
2687 if (isSorted || llvm::is_sorted(kernels))
2688 return Base::get(context, kernels);
2690 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2691 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2692 return Base::get(context, kernelsTmp);
2695KernelTableAttr KernelTableAttr::getChecked(
2697 ArrayRef<KernelMetadataAttr> kernels,
bool isSorted) {
2699 assert((!isSorted || llvm::is_sorted(kernels)) &&
2700 "expected a sorted kernel array");
2702 if (isSorted || llvm::is_sorted(kernels))
2703 return Base::getChecked(
emitError, context, kernels);
2705 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2706 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2707 return Base::getChecked(
emitError, context, kernelsTmp);
2712 ArrayRef<KernelMetadataAttr> kernels) {
2713 if (kernels.size() < 2)
2716 if (std::adjacent_find(kernels.begin(), kernels.end(),
2717 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2718 return l.getName() == r.getName();
2719 }) != kernels.end()) {
2720 return emitError() <<
"expected all kernels to be uniquely named";
2725KernelMetadataAttr KernelTableAttr::lookup(StringRef key)
const {
2727 return found ? *iterator : KernelMetadataAttr();
2730KernelMetadataAttr KernelTableAttr::lookup(StringAttr key)
const {
2732 return found ? *iterator : KernelMetadataAttr();
2812 return CompilationTarget::Fatbin;
2815std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2817 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
options;
2818 llvm::StringSaver stringSaver(
options.first);
2824 if (!opts.empty() && opts.front() ==
'"' && opts.back() ==
'"')
2825 opts.consume_front(
"\""), opts.consume_back(
"\"");
2826 if (!opts.empty() && opts.front() ==
'\'' && opts.back() ==
'\'')
2827 opts.consume_front(
"'"), opts.consume_back(
"'");
2829 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver,
options.second,
2832 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver,
options.second,
2838std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2843std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2845 size_t startPos =
cmdOptions.find(startsWith);
2846 if (startPos == std::string::npos)
2857#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2858#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2860#define GET_ATTRDEF_CLASSES
2861#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2863#define GET_OP_CLASSES
2864#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2866#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
static LogicalResult verifyDistributedType(Type expanded, Type distributed, int64_t warpSize, Operation *op)
Helper check if the distributed vector type is consistent with the expanded type and distributed size...
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices, StringRef keyword)
static LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, PatternRewriter &rewriter)
Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
static std::string getSparseHandleKeyword(SparseHandleKind kind)
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
static StringRef getOperandSegmentSizeAttr()
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
AttrClass getAttrOfType(StringAttr name)
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an efficient unique identifier for a specific C++ type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConcreteType get(MLIRContext *ctx, Args &&...args)
static ConcreteType getChecked(const Location &loc, Args &&...args)
ImplType * getImpl() const
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Type getElementType() const
Get elementType of a single element.
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B.
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction Invariants.
unsigned getNumDims() const
Get number of dims.
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
function_ref< void(llvm::Module &)> optimizedLlvmIRCallback
Callback invoked with LLVM IR for the device module after LLVM optimizations but before codegen.
function_ref< void(StringRef)> getISACallback() const
Returns the callback invoked with the target ISA for the device, for example PTX assembly.
TypeID getTypeID() const
Returns the typeID.
std::string toolkitPath
Path to the target toolkit.
SymbolTable * getSymbolTable() const
Returns the result of the getSymbolTableCallback callback or a nullptr if no callback was provided.
StringRef getELFSection() const
Returns the ELF section.
StringRef getCmdOptions() const
Returns the command line options.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
function_ref< void(StringRef)> isaCallback
Callback invoked with the target ISA for the device, for example PTX assembly.
CompilationTarget compilationTarget
Compilation process target format.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
function_ref< void(llvm::Module &)> initialLlvmIRCallback
Callback invoked with the initial LLVM IR for the device module.
ArrayRef< Attribute > getLibrariesToLink() const
Returns the LLVM libraries to link to.
TargetOptions(StringRef toolkitPath={}, ArrayRef< Attribute > librariesToLink={}, StringRef cmdOptions={}, StringRef elfSection={}, CompilationTarget compilationTarget=getDefaultCompilationTarget(), function_ref< SymbolTable *()> getSymbolTableCallback={}, function_ref< void(llvm::Module &)> initialLlvmIRCallback={}, function_ref< void(llvm::Module &)> linkedLlvmIRCallback={}, function_ref< void(llvm::Module &)> optimizedLlvmIRCallback={}, function_ref< void(StringRef)> isaCallback={})
Constructor initializing the toolkit path, the list of files to link to, extra command line options,...
function_ref< void(llvm::Module &)> getOptimizedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after LLVM optimizations but before c...
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substr of the command line options that starts with startsWith and ends...
StringRef getToolkitPath() const
Returns the toolkit path.
SmallVector< Attribute > librariesToLink
List of files to link with the LLVM module.
function_ref< void(llvm::Module &)> linkedLlvmIRCallback
Callback invoked with LLVM IR for the device module after linking the device libraries.
function_ref< void(llvm::Module &)> getInitialLlvmIRCallback() const
Returns the callback invoked with the initial LLVM IR for the device module.
function_ref< SymbolTable *()> getSymbolTableCallback
Callback for obtaining the parent symbol table of all the GPU modules being serialized.
static CompilationTarget getDefaultCompilationTarget()
Returns the default compilation target: CompilationTarget::Fatbin.
function_ref< void(llvm::Module &)> getLinkedLlvmIRCallback() const
Returns the callback invoked with LLVM IR for the device module after linking the device libraries.
std::string elfSection
ELF Section where the binary needs to be located.
CompilationTarget getCompilationTarget() const
Returns the compilation target.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the con...
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
std::pair< IteratorT, bool > findAttrSorted(IteratorT first, IteratorT last, StringRef name)
Using llvm::lower_bound requires an extra string comparison to check whether the returned iterator po...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto getChecked(function_ref< InFlightDiagnostic()> emitError, MLIRContext *context, Ts &&...params)
Helper method analogous to get, but uses getChecked when available to allow graceful failure on inval...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
llvm::function_ref< Fn > function_ref
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
UnresolvedOperand ssaName
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Utility class for the GPU dialect to represent triples of Values accessible through ....