31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/AsmParser/Parser.h"
34 #include "llvm/IR/Attributes.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/SourceMgr.h"
39 #include "llvm/Support/raw_ostream.h"
47 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
48 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
80 return emitError(
"expects coordinates between 1 to 5 dimension");
83 if (!getIm2colOffsets().empty()) {
86 "to use im2col mode, the tensor has to be at least 3-dimensional");
89 "im2col offsets must be 2 less than number of coordinates");
96 return emitError(
"Maximum 5 coordinates and dimension is supported.");
101 if (getModifier() != LoadCacheModifierKind::CG &&
102 getModifier() != LoadCacheModifierKind::CA)
103 return emitError(
"Only CG and CA cache modifiers are supported.");
104 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
105 return emitError(
"expected byte size to be either 4, 8 or 16.");
106 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
107 return emitError(
"CG cache modifier is only support for 16 bytes copy.");
114 std::optional<mlir::NVVM::MMATypes>
115 MmaOp::inferOperandMMAType(
Type operandElType,
bool isAccumulator) {
118 if (operandElType.
isF64())
119 return NVVM::MMATypes::f64;
120 if (operandElType.
isF16() || operandElType == half2Type)
121 return NVVM::MMATypes::f16;
122 if (operandElType.
isF32() && isAccumulator)
123 return NVVM::MMATypes::f32;
124 if (operandElType.
isF32() && !isAccumulator)
125 return NVVM::MMATypes::tf32;
126 if (llvm::isa<IntegerType>(operandElType)) {
128 return NVVM::MMATypes::s32;
132 if (
auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
133 if (structType.getBody().empty())
135 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
142 return (type == MMATypes::u4 || type == MMATypes::s4);
146 return (type == MMATypes::u8 || type == MMATypes::s8);
151 type == MMATypes::s32;
154 MMATypes MmaOp::accumPtxType() {
155 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
156 getODSOperands(2).getTypes().front(),
true);
157 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
161 MMATypes MmaOp::resultPtxType() {
162 std::optional<mlir::NVVM::MMATypes> val =
163 inferOperandMMAType(getResult().
getType(),
true);
164 assert(val.has_value() &&
"result PTX type should always be inferrable");
170 struct OperandFragment {
171 StringRef operandName;
172 StringRef ptxTypeAttr;
174 explicit OperandFragment(StringRef name, StringRef ptxTypeName)
175 : operandName(name), ptxTypeAttr(ptxTypeName) {}
178 std::array<OperandFragment, 3> frags{
179 OperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
180 OperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
181 OperandFragment(
"C",
"")};
183 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
185 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
186 auto &frag = frags[fragIdx];
187 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
188 for (
auto operandIdx = varOperandSpec.first;
189 operandIdx < varOperandSpec.first + varOperandSpec.second;
191 frag.regs.push_back(this->getOperand(operandIdx));
192 if (operandIdx == 0) {
193 regTypes.push_back(this->getOperand(operandIdx).
getType());
196 std::optional<MMATypes> inferredType =
197 inferOperandMMAType(regTypes.back(), fragIdx >= 2);
199 ignoreAttrNames.push_back(frag.ptxTypeAttr);
202 auto printMmaOperand = [&](
const OperandFragment &frag) ->
void {
203 p <<
" " << frag.operandName;
209 for (
const auto &frag : frags) {
210 printMmaOperand(frag);
228 std::optional<MMAIntOverflow> intOverflow,
229 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
230 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
232 assert(shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
235 "shape", builder.
getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
241 if (multiplicandPtxTypes) {
247 if (
auto res = inferOperandMMAType(operandA[0].
getType(),
false))
249 if (
auto res = inferOperandMMAType(operandB[0].
getType(),
false))
253 if (multiplicandLayouts) {
263 if (intOverflow.has_value())
266 if (b1Op.has_value())
271 MmaOp::getOperandSegmentSizeAttr(),
273 static_cast<int32_t>(operandB.size()),
274 static_cast<int32_t>(operandC.size())}));
282 struct OperandFragment {
283 std::optional<MMATypes> elemtype;
289 std::array<OperandFragment, 4> frags;
294 auto parseMmaOperand = [&](StringRef operandName,
295 OperandFragment &frag) -> LogicalResult {
306 if (parseMmaOperand(
"A", frags[0]).failed())
308 if (parseMmaOperand(
"B", frags[1]).failed())
310 if (parseMmaOperand(
"C", frags[2]).failed())
325 if (operandTypes.size() != 3)
328 "expected one type for each operand segment but got " +
329 Twine(operandTypes.size()) +
" types");
331 auto &frag = frags[iter.index()];
332 frag.regTypes.resize(frag.regs.size(), iter.value());
337 inferOperandMMAType(frag.regTypes[0], iter.index() < 2);
343 frags[3].elemtype = inferOperandMMAType(resultType,
true);
345 std::array<StringRef, 2> names{
"multiplicandAPtxType",
346 "multiplicandBPtxType"};
347 for (
unsigned idx = 0; idx < names.size(); idx++) {
348 const auto &frag = frags[idx];
349 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
350 if (!frag.elemtype.has_value() && !attr.has_value()) {
353 "attribute " + names[idx] +
354 " is not provided explicitly and cannot be inferred");
356 if (!attr.has_value())
362 if (!namedAttributes.
empty())
366 static_cast<int32_t>(frags[0].regs.size()),
367 static_cast<int32_t>(frags[1].regs.size()),
368 static_cast<int32_t>(frags[2].regs.size()),
380 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
386 auto f16x2x2StructTy =
393 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
394 getShapeAttr().getK()};
400 AllowedShapes allowedShapes;
401 AllowedTypes expectedA;
402 AllowedTypes expectedB;
403 AllowedTypes expectedC;
408 if (mmaShape[0] == 16) {
410 Type multiplicandFragType;
411 switch (*getMultiplicandAPtxType()) {
414 multiplicandFragType = i32Ty;
416 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
421 multiplicandFragType = f16x2Ty;
422 expectedResult.push_back(f16x2x2StructTy);
423 expectedResult.push_back(f32x4StructTy);
437 return emitError(
"invalid shape or multiplicand type: " +
438 stringifyEnum(getMultiplicandAPtxType().value()));
442 expectedResult.push_back(s32x4StructTy);
443 expectedC.emplace_back(4, i32Ty);
444 multiplicandFragType = i32Ty;
446 expectedC.emplace_back(2, f16x2Ty);
447 expectedC.emplace_back(4, f32Ty);
450 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
451 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
452 expectedA.emplace_back(unitA, multiplicandFragType);
453 expectedB.emplace_back(unitB, multiplicandFragType);
454 allowedShapes.push_back({16, 8, kFactor});
455 allowedShapes.push_back({16, 8, kFactor * 2});
459 if (mmaShape[0] == 8) {
460 if (*getMultiplicandAPtxType() == MMATypes::f16) {
461 expectedA.emplace_back(2, f16x2Ty);
462 expectedB.emplace_back(2, f16x2Ty);
463 expectedResult.push_back(f16x2x4StructTy);
464 expectedResult.push_back(f32x8StructTy);
465 expectedC.emplace_back(4, f16x2Ty);
466 expectedC.emplace_back(8, f32Ty);
467 allowedShapes.push_back({8, 8, 4});
469 if (*getMultiplicandAPtxType() == MMATypes::f64) {
471 expectedA.emplace_back(1, f64Ty);
472 expectedB.emplace_back(1, f64Ty);
473 expectedC.emplace_back(2, f64Ty);
477 allowedShapes.push_back({8, 8, 4});
480 expectedA.push_back({i32Ty});
481 expectedB.push_back({i32Ty});
482 expectedC.push_back({i32Ty, i32Ty});
483 expectedResult.push_back(s32x2StructTy);
485 allowedShapes.push_back({8, 8, 32});
487 allowedShapes.push_back({8, 8, 16});
488 if (getMultiplicandAPtxType().value() == MMATypes::b1)
489 allowedShapes.push_back({8, 8, 128});
493 std::string errorMessage;
494 llvm::raw_string_ostream errorStream(errorMessage);
497 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
498 !llvm::is_contained(allowedShapes, mmaShape)) {
499 errorStream <<
"unimplemented variant for MMA shape <";
500 llvm::interleaveComma(mmaShape, errorStream);
502 return emitOpError(errorMessage);
506 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
509 auto spec = this->getODSOperandIndexAndLength(iter.index());
511 operand_type_begin() + spec.first +
513 bool match = llvm::is_contained(iter.value(), operandTySeg);
516 errorStream <<
"Could not match types for the "
517 << operandNames[iter.index()]
518 <<
" operands; expected one of ";
519 for (
const auto &x : iter.value()) {
520 errorStream << x.size() <<
"x" << x[0] <<
" ";
522 errorStream <<
"but got ";
523 llvm::interleaveComma(operandTySeg, errorStream);
524 return emitOpError(errorStream.str());
529 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
530 return expectedResultType == getResult().getType();
533 <<
"Could not match allowed types for the result; expected one of ";
534 llvm::interleaveComma(expectedResult, errorStream);
535 errorStream <<
" but got " << getResult().getType();
536 return emitOpError(errorStream.str());
540 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
541 return emitOpError(
"op requires " + getB1OpAttrName().strref() +
549 if (!getIntOverflowBehavior())
550 return emitOpError(
"op requires " +
551 getIntOverflowBehaviorAttrName().strref() +
559 if (!(*this)->getAttrOfType<UnitAttr>(
"return_value_and_is_valid"))
561 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(
getType());
562 auto elementType = (type && type.getBody().size() == 2)
563 ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
565 if (!elementType || elementType.getWidth() != 1)
566 return emitError(
"expected return type to be a two-element struct with "
567 "i1 as the second element");
572 NVVM::MMAFrag frag,
int nRow,
575 unsigned numberElements = 0;
579 if (type == NVVM::MMATypes::f16) {
581 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
585 }
else if (type == NVVM::MMATypes::f32) {
588 }
else if (type == NVVM::MMATypes::tf32) {
591 }
else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
593 int parallelSize = 0;
594 if (frag == NVVM::MMAFrag::a)
596 if (frag == NVVM::MMAFrag::b)
600 if (parallelSize == 16)
603 else if (parallelSize == 8)
605 else if (parallelSize == 32)
607 }
else if (type == NVVM::MMATypes::s32) {
611 assert(numberElements != 0 && elementType !=
nullptr);
612 return std::make_pair(elementType, numberElements);
615 static std::pair<mlir::Type, unsigned>
619 if (frag == NVVM::MMAFrag::a) {
622 }
else if (frag == NVVM::MMAFrag::b) {
629 assert(nRow && nCol);
634 unsigned addressSpace =
635 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
638 return emitOpError(
"expected source pointer in memory "
641 if (NVVM::WMMALoadOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
642 getEltype(), getFrag()) == 0)
643 return emitOpError() <<
"invalid attribute combination";
649 return emitOpError(
"expected destination type is a structure of ")
650 << typeInfo.second <<
" elements of type " << typeInfo.first;
655 unsigned addressSpace =
656 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
659 return emitOpError(
"expected operands to be a source pointer in memory "
662 if (NVVM::WMMAStoreOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
664 return emitOpError() <<
"invalid attribute combination";
667 if (getArgs().size() != typeInfo.second)
668 return emitOpError() <<
"expected " << typeInfo.second <<
" data operands";
669 if (llvm::any_of(getArgs(), [&typeInfo](
Value operands) {
670 return operands.
getType() != typeInfo.first;
672 return emitOpError() <<
"expected data operands of type " << typeInfo.first;
677 if (NVVM::WMMAMmaOp::getIntrinsicID(
getM(),
getN(), getK(), getLayoutA(),
678 getLayoutB(), getEltypeA(),
680 return emitOpError() <<
"invalid attribute combination";
688 arguments.append(typeInfoA.second, typeInfoA.first);
689 arguments.append(typeInfoB.second, typeInfoB.first);
690 arguments.append(typeInfoC.second, typeInfoC.first);
691 unsigned numArgs = arguments.size();
692 if (getArgs().size() != numArgs)
693 return emitOpError() <<
"expected " << numArgs <<
" arguments";
694 for (
unsigned i = 0; i < numArgs; i++) {
695 if (getArgs()[i].
getType() != arguments[i])
696 return emitOpError() <<
"expected argument " << i <<
" to be of type "
702 return emitOpError(
"expected destination type is a structure of ")
703 << typeInfoC.second <<
" elements of type " << typeInfoC.first;
708 unsigned addressSpace =
709 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
711 return emitOpError(
"expected source pointer in memory space 3");
713 if (getNum() != 1 && getNum() != 2 && getNum() != 4)
714 return emitOpError(
"expected num attribute to be 1, 2 or 4");
717 if (getNum() == 1 &&
getType() != i32)
718 return emitOpError(
"expected destination type is i32");
719 if (getNum() == 2 || getNum() == 4) {
723 return emitOpError(
"expected destination type is a structure of ")
724 << getNum() <<
" elements of type i32";
730 unsigned addressSpace =
731 llvm::cast<LLVM::LLVMPointerType>(getPtr().
getType()).getAddressSpace();
733 return emitOpError(
"expected source pointer in memory space 3");
735 int numMatrix = getSources().size();
736 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
737 return emitOpError(
"expected num attribute to be 1, 2 or 4");
743 if (typeA == NVVM::WGMMATypes::tf32)
745 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
747 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
749 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
751 if (typeA == NVVM::WGMMATypes::b1)
757 NVVM::WGMMATypes typeA,
758 NVVM::WGMMATypes typeB) {
760 case NVVM::WGMMATypes::f16:
761 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
762 typeB == NVVM::WGMMATypes::f16)
765 case NVVM::WGMMATypes::tf32:
766 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
769 case NVVM::WGMMATypes::u8:
770 case NVVM::WGMMATypes::s8:
771 if (typeD == NVVM::WGMMATypes::s32 &&
772 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
775 case NVVM::WGMMATypes::b1:
776 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
779 case NVVM::WGMMATypes::bf16:
780 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
781 typeB == NVVM::WGMMATypes::bf16)
784 case NVVM::WGMMATypes::e4m3:
785 case NVVM::WGMMATypes::e5m2:
786 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
787 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
790 case WGMMATypes::f32:
791 case WGMMATypes::s32:
792 llvm_unreachable(
"unsupported input types");
800 72, 80, 88, 96, 104, 112, 120, 128,
801 136, 144, 152, 160, 168, 176, 184, 192,
802 200, 208, 216, 224, 232, 240, 248, 256};
804 80, 96, 112, 128, 144, 160,
805 176, 192, 208, 224, 240, 256};
807 case WGMMATypes::f16:
808 case WGMMATypes::tf32:
809 case WGMMATypes::bf16:
810 case WGMMATypes::e4m3:
811 case WGMMATypes::e5m2:
812 if (llvm::is_contained(allowedN, sizeN))
818 if (llvm::is_contained(allowedNshort, sizeN))
821 case WGMMATypes::f32:
822 case WGMMATypes::s32:
823 llvm_unreachable(
"unsupported input types");
830 Value outValue = getResults();
831 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
833 return emitOpError() <<
"expected results to be struct";
834 int outputSize = stype.getBody().size();
835 WGMMATypes typeD = getTypeD();
836 WGMMATypes typeA = getTypeA();
837 WGMMATypes typeB = getTypeB();
839 for (
Type t : stype.getBody()) {
840 if (t != stype.getBody().front())
842 <<
"all elements in struct must be same type but there is " << t;
845 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
846 typeD != WGMMATypes::s32) {
847 return emitOpError() <<
"does not support the given output type "
848 << NVVM::stringifyWGMMATypes(typeD);
850 if (typeD == WGMMATypes::s32 &&
851 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
852 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
856 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
857 <<
" += " << NVVM::stringifyWGMMATypes(typeA) <<
" * "
858 << NVVM::stringifyWGMMATypes(typeB)
859 <<
", it is not supported.";
864 return emitOpError() <<
"shape 'm' must be 64";
868 if (failed(allowedK) || allowedK.value() !=
getShape().getK())
869 return emitOpError() <<
"shape 'k' must be " << allowedK.value()
870 <<
" for input type "
871 << NVVM::stringifyWGMMATypes(typeA);
875 return emitOpError() <<
"has input type "
876 << NVVM::stringifyWGMMATypes(typeA) <<
" n is set to "
877 <<
getShape().getN() <<
", it is not supported.";
884 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
885 (getLayoutA() == mlir::NVVM::MMALayout::col ||
886 getLayoutB() == mlir::NVVM::MMALayout::row)) {
888 <<
"given layouts layout_a = " << stringifyMMALayout(getLayoutA())
889 <<
" and layout_b = " << stringifyMMALayout(getLayoutB())
890 <<
" for input types " << stringifyWGMMATypes(typeA) <<
" and "
891 << stringifyWGMMATypes(typeB)
892 <<
" requires transpose. However, this is only supported for: "
893 << stringifyMMATypes(MMATypes::f16) <<
" and "
894 << stringifyMMATypes(MMATypes::bf16);
898 int expectedOutput = 0;
899 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
900 expectedOutput =
getShape().getN() / 2;
901 if (typeD == WGMMATypes::f16)
902 expectedOutput =
getShape().getN() / 4;
903 if (outputSize != expectedOutput) {
904 return emitOpError() <<
"results " << expectedOutput
905 <<
", however output struct has " << outputSize
909 if (typeD != WGMMATypes::s32 &&
910 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
911 NVVM::MMAIntOverflow::satfinite) {
913 <<
" `satfinite` can be only used with s32 accumulator, however "
914 "the current accumulator is "
915 << NVVM::stringifyWGMMATypes(typeD);
921 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
924 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
926 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
928 int expectedOutputRegisters = 0;
929 if (getTypeD() == WGMMATypes::f16)
930 expectedOutputRegisters =
getShape().getN() / 4;
932 expectedOutputRegisters =
getShape().getN() / 2;
935 llvm::raw_string_ostream ss(ptx);
940 << ((expectedOutputRegisters * 2) + 2)
942 "wgmma.mma_async.sync.aligned.m"
943 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"."
944 << stringifyWGMMATypes(getTypeA()) <<
"."
945 << stringifyWGMMATypes(getTypeB());
946 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
947 NVVM::MMAIntOverflow::satfinite)
951 for (; regCnt < expectedOutputRegisters; ++regCnt) {
953 if (regCnt != expectedOutputRegisters - 1)
959 regCnt = (regCnt * 2);
960 ss <<
" $" << (regCnt) <<
"," <<
" $" << (regCnt + 1) <<
"," <<
" p";
961 if (getTypeD() != WGMMATypes::s32) {
962 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
966 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
974 void NVVM::WgmmaMmaAsyncOp::getAsmValues(
978 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
985 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
987 if (getTypeD() != WGMMATypes::s32) {
989 {makeConstantI32(rewriter,
990 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
993 {makeConstantI32(rewriter,
994 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
999 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
1001 asmValues.push_back(
1002 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
1007 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1008 return emitOpError() <<
"async_shared fence requires space attribute";
1010 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1011 return emitOpError() <<
"only async_shared fence can have space attribute";
1017 if (getRegCount() % 8)
1018 return emitOpError(
"new register size must be multiple of 8");
1019 if (getRegCount() < 24 || getRegCount() > 256)
1020 return emitOpError(
"new register size must be in between 24 to 256");
1025 if (getNumberOfThreads() && !getBarrierId())
1027 "barrier id is missing, it should be set between 0 to 15");
1036 void NVVMDialect::initialize() {
1039 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1042 #define GET_ATTRDEF_LIST
1043 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1048 allowUnknownOperations();
1049 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1050 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1053 LogicalResult NVVMDialect::verifyOperationAttribute(
Operation *op,
1055 StringAttr attrName = attr.
getName();
1057 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1058 if (!isa<LLVM::LLVMFuncOp>(op)) {
1059 return op->
emitError() <<
"'" << NVVMDialect::getKernelFuncAttrName()
1060 <<
"' attribute attached to unexpected op";
1064 if (attrName == NVVMDialect::getMaxntidAttrName() ||
1065 attrName == NVVMDialect::getReqntidAttrName()) {
1066 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
1067 if (!values || values.empty() || values.size() > 3)
1070 <<
"' attribute must be integer array with maximum 3 index";
1073 if (attrName == NVVMDialect::getMinctasmAttrName() ||
1074 attrName == NVVMDialect::getMaxnregAttrName()) {
1075 if (!llvm::dyn_cast<IntegerAttr>(attr.
getValue()))
1077 <<
"'" << attrName <<
"' attribute must be integer constant";
1083 LogicalResult NVVMDialect::verifyRegionArgAttribute(
Operation *op,
1084 unsigned regionIndex,
1087 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1091 bool isKernel = op->
hasAttr(NVVMDialect::getKernelFuncAttrName());
1092 StringAttr attrName = argAttr.
getName();
1093 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1097 <<
"' attribute must be present only on kernel arguments";
1099 if (!isa<UnitAttr>(argAttr.
getValue()))
1100 return op->
emitError() <<
"'" << attrName <<
"' must be a unit attribute";
1101 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1104 <<
"' attribute requires the argument to also have attribute '"
1105 << LLVM::LLVMDialect::getByValAttrName() <<
"'";
1117 int optLevel, StringRef triple, StringRef chip,
1118 StringRef features, DictionaryAttr flags,
1120 if (optLevel < 0 || optLevel > 3) {
1121 emitError() <<
"The optimization level must be a number between 0 and 3.";
1124 if (triple.empty()) {
1125 emitError() <<
"The target triple cannot be empty.";
1129 emitError() <<
"The target chip cannot be empty.";
1133 return attr && mlir::isa<StringAttr>(attr);
1135 emitError() <<
"All the elements in the `link` array must be strings.";
1141 #define GET_OP_CLASSES
1142 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1144 #define GET_ATTRDEF_CLASSES
1145 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op)
FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static bool isInt8PtxType(MMATypes type)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class represents a diagnostic that is inflight and set to be reported.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
@ Write
Write register with '=' modifier.
@ ReadWrite
Read and write register with '+' modifier.
@ Read
Read register with no modifier.
@ kGlobalMemorySpace
Global memory space identifier.
@ kSharedMemorySpace
Shared memory space identifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.