32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/AsmParser/Parser.h"
35 #include "llvm/IR/Attributes.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/SourceMgr.h"
40 #include "llvm/Support/raw_ostream.h"
48 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
49 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
81 return emitError(
"expects coordinates between 1 to 5 dimension");
84 if (!getIm2colOffsets().empty()) {
87 "to use im2col mode, the tensor has to be at least 3-dimensional");
90 "im2col offsets must be 2 less than number of coordinates");
97 return emitError(
"Maximum 5 coordinates and dimension is supported.");
102 if (getModifier() != LoadCacheModifierKind::CG &&
103 getModifier() != LoadCacheModifierKind::CA)
104 return emitError(
"Only CG and CA cache modifiers are supported.");
105 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
106 return emitError(
"expected byte size to be either 4, 8 or 16.");
107 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
108 return emitError(
"CG cache modifier is only support for 16 bytes copy.");
115 std::optional<mlir::NVVM::MMATypes>
116 MmaOp::inferOperandMMAType(
Type operandElType,
bool isAccumulator) {
119 if (operandElType.
isF64())
120 return NVVM::MMATypes::f64;
121 if (operandElType.
isF16() || operandElType == half2Type)
122 return NVVM::MMATypes::f16;
123 if (operandElType.
isF32() && isAccumulator)
124 return NVVM::MMATypes::f32;
125 if (operandElType.
isF32() && !isAccumulator)
126 return NVVM::MMATypes::tf32;
127 if (llvm::isa<IntegerType>(operandElType)) {
129 return NVVM::MMATypes::s32;
133 if (
auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
134 if (structType.getBody().empty())
136 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
143 return (type == MMATypes::u4 || type == MMATypes::s4);
147 return (type == MMATypes::u8 || type == MMATypes::s8);
152 type == MMATypes::s32;
155 MMATypes MmaOp::accumPtxType() {
156 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
157 getODSOperands(2).getTypes().front(),
true);
158 assert(val.has_value() &&
"accumulator PTX type should always be inferrable");
162 MMATypes MmaOp::resultPtxType() {
163 std::optional<mlir::NVVM::MMATypes> val =
164 inferOperandMMAType(getResult().getType(),
true);
165 assert(val.has_value() &&
"result PTX type should always be inferrable");
171 struct OperandFragment {
172 StringRef operandName;
173 StringRef ptxTypeAttr;
175 explicit OperandFragment(StringRef name, StringRef ptxTypeName)
176 : operandName(name), ptxTypeAttr(ptxTypeName) {}
179 std::array<OperandFragment, 3> frags{
180 OperandFragment(
"A", getMultiplicandAPtxTypeAttrName()),
181 OperandFragment(
"B", getMultiplicandBPtxTypeAttrName()),
182 OperandFragment(
"C",
"")};
184 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
186 for (
unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
187 auto &frag = frags[fragIdx];
188 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
189 for (
auto operandIdx = varOperandSpec.first;
190 operandIdx < varOperandSpec.first + varOperandSpec.second;
192 frag.regs.push_back(this->getOperand(operandIdx));
193 if (operandIdx == 0) {
194 regTypes.push_back(this->getOperand(operandIdx).getType());
197 std::optional<MMATypes> inferredType =
198 inferOperandMMAType(regTypes.back(), fragIdx >= 2);
200 ignoreAttrNames.push_back(frag.ptxTypeAttr);
203 auto printMmaOperand = [&](
const OperandFragment &frag) ->
void {
204 p <<
" " << frag.operandName;
210 for (
const auto &frag : frags) {
211 printMmaOperand(frag);
219 frags[1].regs[0].getType(),
220 frags[2].regs[0].getType()},
229 std::optional<MMAIntOverflow> intOverflow,
230 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
231 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
233 assert(shape.size() == 3 &&
"expected shape to have size 3 (m, n, k)");
236 "shape", builder.
getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
242 if (multiplicandPtxTypes) {
248 if (
auto res = inferOperandMMAType(operandA[0].getType(),
false))
250 if (
auto res = inferOperandMMAType(operandB[0].getType(),
false))
254 if (multiplicandLayouts) {
264 if (intOverflow.has_value())
267 if (b1Op.has_value())
272 MmaOp::getOperandSegmentSizeAttr(),
274 static_cast<int32_t>(operandB.size()),
275 static_cast<int32_t>(operandC.size())}));
283 struct OperandFragment {
284 std::optional<MMATypes> elemtype;
290 std::array<OperandFragment, 4> frags;
295 auto parseMmaOperand = [&](StringRef operandName,
307 if (parseMmaOperand(
"A", frags[0]).
failed())
309 if (parseMmaOperand(
"B", frags[1]).
failed())
311 if (parseMmaOperand(
"C", frags[2]).
failed())
326 if (operandTypes.size() != 3)
329 "expected one type for each operand segment but got " +
330 Twine(operandTypes.size()) +
" types");
332 auto &frag = frags[iter.index()];
333 frag.regTypes.resize(frag.regs.size(), iter.value());
338 inferOperandMMAType(frag.regTypes[0], iter.index() < 2);
344 frags[3].elemtype = inferOperandMMAType(resultType,
true);
346 std::array<StringRef, 2> names{
"multiplicandAPtxType",
347 "multiplicandBPtxType"};
348 for (
unsigned idx = 0; idx < names.size(); idx++) {
349 const auto &frag = frags[idx];
350 std::optional<NamedAttribute> attr = namedAttributes.
getNamed(names[idx]);
351 if (!frag.elemtype.has_value() && !attr.has_value()) {
354 "attribute " + names[idx] +
355 " is not provided explicitly and cannot be inferred");
357 if (!attr.has_value())
363 if (!namedAttributes.
empty())
367 static_cast<int32_t>(frags[0].regs.size()),
368 static_cast<int32_t>(frags[1].regs.size()),
369 static_cast<int32_t>(frags[2].regs.size()),
381 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
387 auto f16x2x2StructTy =
394 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
395 getShapeAttr().getK()};
401 AllowedShapes allowedShapes;
402 AllowedTypes expectedA;
403 AllowedTypes expectedB;
404 AllowedTypes expectedC;
409 if (mmaShape[0] == 16) {
411 Type multiplicandFragType;
412 switch (*getMultiplicandAPtxType()) {
415 multiplicandFragType = i32Ty;
417 context, {f32Ty, f32Ty, f32Ty, f32Ty}));
422 multiplicandFragType = f16x2Ty;
423 expectedResult.push_back(f16x2x2StructTy);
424 expectedResult.push_back(f32x4StructTy);
438 return emitError(
"invalid shape or multiplicand type: " +
439 stringifyEnum(getMultiplicandAPtxType().value()));
443 expectedResult.push_back(s32x4StructTy);
444 expectedC.emplace_back(4, i32Ty);
445 multiplicandFragType = i32Ty;
447 expectedC.emplace_back(2, f16x2Ty);
448 expectedC.emplace_back(4, f32Ty);
451 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
452 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
453 expectedA.emplace_back(unitA, multiplicandFragType);
454 expectedB.emplace_back(unitB, multiplicandFragType);
455 allowedShapes.push_back({16, 8, kFactor});
456 allowedShapes.push_back({16, 8, kFactor * 2});
460 if (mmaShape[0] == 8) {
461 if (*getMultiplicandAPtxType() == MMATypes::f16) {
462 expectedA.emplace_back(2, f16x2Ty);
463 expectedB.emplace_back(2, f16x2Ty);
464 expectedResult.push_back(f16x2x4StructTy);
465 expectedResult.push_back(f32x8StructTy);
466 expectedC.emplace_back(4, f16x2Ty);
467 expectedC.emplace_back(8, f32Ty);
468 allowedShapes.push_back({8, 8, 4});
470 if (*getMultiplicandAPtxType() == MMATypes::f64) {
472 expectedA.emplace_back(1, f64Ty);
473 expectedB.emplace_back(1, f64Ty);
474 expectedC.emplace_back(2, f64Ty);
478 allowedShapes.push_back({8, 8, 4});
481 expectedA.push_back({i32Ty});
482 expectedB.push_back({i32Ty});
483 expectedC.push_back({i32Ty, i32Ty});
484 expectedResult.push_back(s32x2StructTy);
486 allowedShapes.push_back({8, 8, 32});
488 allowedShapes.push_back({8, 8, 16});
489 if (getMultiplicandAPtxType().value() == MMATypes::b1)
490 allowedShapes.push_back({8, 8, 128});
494 std::string errorMessage;
495 llvm::raw_string_ostream errorStream(errorMessage);
498 if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
499 !llvm::is_contained(allowedShapes, mmaShape)) {
500 errorStream <<
"unimplemented variant for MMA shape <";
501 llvm::interleaveComma(mmaShape, errorStream);
503 return emitOpError(errorMessage);
507 std::array<StringRef, 3> operandNames{
"A",
"B",
"C"};
510 auto spec = this->getODSOperandIndexAndLength(iter.index());
512 operand_type_begin() + spec.first +
514 bool match = llvm::is_contained(iter.value(), operandTySeg);
517 errorStream <<
"Could not match types for the "
518 << operandNames[iter.index()]
519 <<
" operands; expected one of ";
520 for (
const auto &x : iter.value()) {
521 errorStream << x.size() <<
"x" << x[0] <<
" ";
523 errorStream <<
"but got ";
524 llvm::interleaveComma(operandTySeg, errorStream);
525 return emitOpError(errorStream.str());
530 if (!llvm::any_of(expectedResult, [&](
Type expectedResultType) {
531 return expectedResultType == getResult().getType();
534 <<
"Could not match allowed types for the result; expected one of ";
535 llvm::interleaveComma(expectedResult, errorStream);
536 errorStream <<
" but got " << getResult().getType();
537 return emitOpError(errorStream.str());
541 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
542 return emitOpError(
"op requires " + getB1OpAttrName().strref() +
550 if (!getIntOverflowBehavior())
551 return emitOpError(
"op requires " +
552 getIntOverflowBehaviorAttrName().strref() +
560 if (!(*this)->getAttrOfType<UnitAttr>(
"return_value_and_is_valid"))
562 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
563 auto elementType = (type && type.getBody().size() == 2)
564 ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
566 if (!elementType || elementType.getWidth() != 1)
567 return emitError(
"expected return type to be a two-element struct with "
568 "i1 as the second element");
573 NVVM::MMAFrag frag,
int nRow,
576 unsigned numberElements = 0;
580 if (type == NVVM::MMATypes::f16) {
582 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
586 }
else if (type == NVVM::MMATypes::f32) {
589 }
else if (type == NVVM::MMATypes::tf32) {
592 }
else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
594 int parallelSize = 0;
595 if (frag == NVVM::MMAFrag::a)
597 if (frag == NVVM::MMAFrag::b)
601 if (parallelSize == 16)
604 else if (parallelSize == 8)
606 else if (parallelSize == 32)
608 }
else if (type == NVVM::MMATypes::s32) {
612 assert(numberElements != 0 && elementType !=
nullptr);
613 return std::make_pair(elementType, numberElements);
616 static std::pair<mlir::Type, unsigned>
620 if (frag == NVVM::MMAFrag::a) {
623 }
else if (frag == NVVM::MMAFrag::b) {
630 assert(nRow && nCol);
635 unsigned addressSpace =
636 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
639 return emitOpError(
"expected source pointer in memory "
642 if (NVVM::WMMALoadOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
643 getEltype(), getFrag()) == 0)
644 return emitOpError() <<
"invalid attribute combination";
649 if (getType() != dstType)
650 return emitOpError(
"expected destination type is a structure of ")
651 << typeInfo.second <<
" elements of type " << typeInfo.first;
656 unsigned addressSpace =
657 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
660 return emitOpError(
"expected operands to be a source pointer in memory "
663 if (NVVM::WMMAStoreOp::getIntrinsicID(
getM(),
getN(), getK(), getLayout(),
665 return emitOpError() <<
"invalid attribute combination";
668 if (getArgs().size() != typeInfo.second)
669 return emitOpError() <<
"expected " << typeInfo.second <<
" data operands";
670 if (llvm::any_of(getArgs(), [&typeInfo](
Value operands) {
671 return operands.
getType() != typeInfo.first;
673 return emitOpError() <<
"expected data operands of type " << typeInfo.first;
678 if (NVVM::WMMAMmaOp::getIntrinsicID(
getM(),
getN(), getK(), getLayoutA(),
679 getLayoutB(), getEltypeA(),
681 return emitOpError() <<
"invalid attribute combination";
689 arguments.append(typeInfoA.second, typeInfoA.first);
690 arguments.append(typeInfoB.second, typeInfoB.first);
691 arguments.append(typeInfoC.second, typeInfoC.first);
692 unsigned numArgs = arguments.size();
693 if (getArgs().size() != numArgs)
694 return emitOpError() <<
"expected " << numArgs <<
" arguments";
695 for (
unsigned i = 0; i < numArgs; i++) {
696 if (getArgs()[i].getType() != arguments[i])
697 return emitOpError() <<
"expected argument " << i <<
" to be of type "
702 if (getType() != dstType)
703 return emitOpError(
"expected destination type is a structure of ")
704 << typeInfoC.second <<
" elements of type " << typeInfoC.first;
709 unsigned addressSpace =
710 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
712 return emitOpError(
"expected source pointer in memory space 3");
714 if (getNum() != 1 && getNum() != 2 && getNum() != 4)
715 return emitOpError(
"expected num attribute to be 1, 2 or 4");
718 if (getNum() == 1 && getType() != i32)
719 return emitOpError(
"expected destination type is i32");
720 if (getNum() == 2 || getNum() == 4) {
723 if (getType() != dstType)
724 return emitOpError(
"expected destination type is a structure of ")
725 << getNum() <<
" elements of type i32";
731 unsigned addressSpace =
732 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
734 return emitOpError(
"expected source pointer in memory space 3");
736 int numMatrix = getSources().size();
737 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
738 return emitOpError(
"expected num attribute to be 1, 2 or 4");
744 if (typeA == NVVM::WGMMATypes::tf32)
746 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
748 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
750 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
752 if (typeA == NVVM::WGMMATypes::b1)
758 NVVM::WGMMATypes typeA,
759 NVVM::WGMMATypes typeB) {
761 case NVVM::WGMMATypes::f16:
762 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
763 typeB == NVVM::WGMMATypes::f16)
766 case NVVM::WGMMATypes::tf32:
767 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
770 case NVVM::WGMMATypes::u8:
771 case NVVM::WGMMATypes::s8:
772 if (typeD == NVVM::WGMMATypes::s32 &&
773 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
776 case NVVM::WGMMATypes::b1:
777 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
780 case NVVM::WGMMATypes::bf16:
781 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
782 typeB == NVVM::WGMMATypes::bf16)
785 case NVVM::WGMMATypes::e4m3:
786 case NVVM::WGMMATypes::e5m2:
787 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
788 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
791 case WGMMATypes::f32:
792 case WGMMATypes::s32:
793 llvm_unreachable(
"unsupported input types");
801 72, 80, 88, 96, 104, 112, 120, 128,
802 136, 144, 152, 160, 168, 176, 184, 192,
803 200, 208, 216, 224, 232, 240, 248, 256};
805 80, 96, 112, 128, 144, 160,
806 176, 192, 208, 224, 240, 256};
808 case WGMMATypes::f16:
809 case WGMMATypes::tf32:
810 case WGMMATypes::bf16:
811 case WGMMATypes::e4m3:
812 case WGMMATypes::e5m2:
813 if (llvm::is_contained(allowedN, sizeN))
819 if (llvm::is_contained(allowedNshort, sizeN))
822 case WGMMATypes::f32:
823 case WGMMATypes::s32:
824 llvm_unreachable(
"unsupported input types");
831 Value outValue = getResults();
832 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.
getType());
834 return emitOpError() <<
"expected results to be struct";
835 int outputSize = stype.getBody().size();
836 WGMMATypes typeD = getTypeD();
837 WGMMATypes typeA = getTypeA();
838 WGMMATypes typeB = getTypeB();
840 for (
Type t : stype.getBody()) {
841 if (t != stype.getBody().front())
843 <<
"all elements in struct must be same type but there is " << t;
846 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
847 typeD != WGMMATypes::s32) {
848 return emitOpError() <<
"does not support the given output type "
849 << NVVM::stringifyWGMMATypes(typeD);
851 if (typeD == WGMMATypes::s32 &&
852 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
853 return emitOpError() <<
"has s32 output, scaleA and scaleB cannot be neg";
857 return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
858 <<
" += " << NVVM::stringifyWGMMATypes(typeA) <<
" * "
859 << NVVM::stringifyWGMMATypes(typeB)
860 <<
", it is not supported.";
865 return emitOpError() <<
"shape 'm' must be 64";
870 return emitOpError() <<
"shape 'k' must be " << allowedK.value()
871 <<
" for input type "
872 << NVVM::stringifyWGMMATypes(typeA);
876 return emitOpError() <<
"has input type "
877 << NVVM::stringifyWGMMATypes(typeA) <<
" n is set to "
878 <<
getShape().getN() <<
", it is not supported.";
882 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
883 (getLayoutA() == mlir::NVVM::MMALayout::col ||
884 getLayoutB() == mlir::NVVM::MMALayout::col)) {
886 <<
"given layouts layout_a = " << stringifyMMALayout(getLayoutA())
887 <<
" and layout_b = " << stringifyMMALayout(getLayoutB())
888 <<
" for input types " << stringifyWGMMATypes(typeA) <<
" and "
889 << stringifyWGMMATypes(typeB)
890 <<
" requires transpose. However, this is only supported for: "
891 << stringifyMMATypes(MMATypes::f16) <<
" and "
892 << stringifyMMATypes(MMATypes::bf16);
896 int expectedOutput = 0;
897 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
898 expectedOutput =
getShape().getN() / 2;
899 if (typeD == WGMMATypes::f16)
900 expectedOutput =
getShape().getN() / 4;
901 if (outputSize != expectedOutput) {
902 return emitOpError() <<
"results " << expectedOutput
903 <<
", however output struct has " << outputSize
907 if (typeD != WGMMATypes::s32 &&
908 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
909 NVVM::MMAIntOverflow::satfinite) {
911 <<
" `satfinite` can be only used with s32 accumulator, however "
912 "the current accumulator is "
913 << NVVM::stringifyWGMMATypes(typeD);
919 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
922 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
924 StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
926 int expectedOutputRegisters = 0;
927 if (getTypeD() == WGMMATypes::f16)
928 expectedOutputRegisters =
getShape().getN() / 4;
930 expectedOutputRegisters =
getShape().getN() / 2;
933 llvm::raw_string_ostream ss(ptx);
938 << ((expectedOutputRegisters * 2) + 2)
940 "wgmma.mma_async.sync.aligned.m"
941 << m <<
"n" << n <<
"k" << k <<
"." << outputTypeName <<
"."
942 << stringifyWGMMATypes(getTypeA()) <<
"."
943 << stringifyWGMMATypes(getTypeB());
944 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
945 NVVM::MMAIntOverflow::satfinite)
949 for (; regCnt < expectedOutputRegisters; ++regCnt) {
951 if (regCnt != expectedOutputRegisters - 1)
957 regCnt = (regCnt * 2);
958 ss <<
" $" << (regCnt) <<
"," <<
" $" << (regCnt + 1) <<
"," <<
" p";
959 if (getTypeD() != WGMMATypes::s32) {
960 ss <<
", $" << (regCnt + 3) <<
", $" << (regCnt + 4);
964 ss <<
", $" << (regCnt + 5) <<
", $" << (regCnt + 6);
972 void NVVM::WgmmaMmaAsyncOp::getAsmValues(
976 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
983 asmValues.push_back({makeConstantI32(rewriter,
static_cast<int>(getScaleD())),
985 if (getTypeD() != WGMMATypes::s32) {
987 {makeConstantI32(rewriter,
988 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
991 {makeConstantI32(rewriter,
992 getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
997 {makeConstantI32(rewriter,
static_cast<int>(getLayoutA())),
1000 {makeConstantI32(rewriter, 1 -
static_cast<int>(getLayoutB())),
1005 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1006 return emitOpError() <<
"async_shared fence requires space attribute";
1008 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1009 return emitOpError() <<
"only async_shared fence can have space attribute";
1015 if (getRegCount() % 8)
1016 return emitOpError(
"new register size must be multiple of 8");
1017 if (getRegCount() < 24 || getRegCount() > 256)
1018 return emitOpError(
"new register size must be in between 24 to 256");
1023 if (getNumberOfThreads() && !getBarrierId())
1025 "barrier id is missing, it should be set between 0 to 15");
1034 void NVVMDialect::initialize() {
1037 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1040 #define GET_ATTRDEF_LIST
1041 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1046 allowUnknownOperations();
1047 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1048 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1053 StringAttr attrName = attr.
getName();
1055 if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1056 if (!isa<LLVM::LLVMFuncOp>(op)) {
1057 return op->
emitError() <<
"'" << NVVMDialect::getKernelFuncAttrName()
1058 <<
"' attribute attached to unexpected op";
1062 if (attrName == NVVMDialect::getMaxntidAttrName() ||
1063 attrName == NVVMDialect::getReqntidAttrName()) {
1064 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.
getValue());
1065 if (!values || values.empty() || values.size() > 3)
1068 <<
"' attribute must be integer array with maximum 3 index";
1071 if (attrName == NVVMDialect::getMinctasmAttrName() ||
1072 attrName == NVVMDialect::getMaxnregAttrName()) {
1073 if (!llvm::dyn_cast<IntegerAttr>(attr.
getValue()))
1075 <<
"'" << attrName <<
"' attribute must be integer constant";
1082 unsigned regionIndex,
1085 auto funcOp = dyn_cast<FunctionOpInterface>(op);
1089 bool isKernel = op->
hasAttr(NVVMDialect::getKernelFuncAttrName());
1090 StringAttr attrName = argAttr.
getName();
1091 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1095 <<
"' attribute must be present only on kernel arguments";
1097 if (!isa<UnitAttr>(argAttr.
getValue()))
1098 return op->
emitError() <<
"'" << attrName <<
"' must be a unit attribute";
1099 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1102 <<
"' attribute requires the argument to also have attribute '"
1103 << LLVM::LLVMDialect::getByValAttrName() <<
"'";
1115 int optLevel, StringRef triple, StringRef chip,
1116 StringRef features, DictionaryAttr flags,
1118 if (optLevel < 0 || optLevel > 3) {
1119 emitError() <<
"The optimization level must be a number between 0 and 3.";
1122 if (triple.empty()) {
1123 emitError() <<
"The target triple cannot be empty.";
1127 emitError() <<
"The target chip cannot be empty.";
1131 return attr && mlir::isa<StringAttr>(attr);
1133 emitError() <<
"All the elements in the `link` array must be strings.";
1139 #define GET_OP_CLASSES
1140 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1142 #define GET_ATTRDEF_CLASSES
1143 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA)
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op)
FailureOr< int > getAllowedSizeK(NVVM::WGMMATypes typeA)
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static bool isInt8PtxType(MMATypes type)
static bool isInt4PtxType(MMATypes type)
static bool isIntegerPtxType(MMATypes type)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class provides support for representing a failure result, or a valid value of type T.
This class represents a diagnostic that is inflight and set to be reported.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
@ Write
Read register with '+' modifier.
@ ReadWrite
Read register with '=' modifier.
@ Read
Read register with no modifier.
@ kGlobalMemorySpace
Global memory space identifier.
@ kSharedMemorySpace
Shared memory space identifier.
std::pair< mlir::Type, unsigned > inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, int nRow, int nCol, mlir::MLIRContext *context)
Return the element type and number of elements associated with a wmma matrix of given characteristics.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< Type, 4 > types
Types of the results of this operation.