10 #include "../PassDetail.h" 25 auto i32Ty = IntegerType::get(ctx, 32);
27 Type f64Ty = Float64Type::get(ctx);
29 Type f32Ty = Float32Type::get(ctx);
31 if (a.getElementType() == f16x2Ty) {
35 if (a.getElementType() == i32x2Ty) {
40 if (a.getElementType() == f64x2Ty) {
43 if (a.getElementType() == f32x2Ty) {
52 return vectorResultType;
75 auto makeConst = [&](int32_t index) ->
Value {
76 return rewriter.
create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
85 if (arrayType.getElementType() == f16x2Ty ||
86 arrayType.getElementType() == f32x1Ty) {
87 for (
unsigned i = 0; i < structType.getBody().size(); i++) {
89 loc, structType.getBody()[i], intrinsicResult,
92 loc, arrayType.getElementType(), el);
93 elements.push_back(el);
101 if (arrayType.getElementType() == i32x2Ty ||
102 arrayType.getElementType() == f64x2Ty ||
103 arrayType.getElementType() == f32x2Ty) {
105 for (
unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
107 rewriter.
create<LLVM::UndefOp>(loc, arrayType.getElementType());
109 loc, structType.getBody()[i * 2], intrinsicResult,
112 loc, structType.getBody()[i * 2 + 1], intrinsicResult,
114 vec = rewriter.
create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
116 vec = rewriter.
create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
118 elements.push_back(vec);
123 Value result = rewriter.
create<LLVM::UndefOp>(loc, arrayType);
125 result = rewriter.
create<LLVM::InsertValueOp>(
126 loc, arrayType, result, el.value(),
132 return intrinsicResult;
/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vector`s where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments as a flat list of scalars of
/// certain types. This function unpacks the `vector` arguments and casts them
/// to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
                                              Location loc, Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = rewriter.getI32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f32Ty = rewriter.getF32Type();
  Type i8x4Ty = LLVM::getFixedVectorType(rewriter.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(rewriter.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));

    // 4xi8, 8xi4, and (for tf32) 1xf32 vectors are handed to the intrinsic as
    // i32 scalars.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(
          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
      continue;
    }

    // For i32, f32, and f64 element types, the intrinsic expects individual
    // scalars, so the inner vector must be unpacked as well.
    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
            loc, toUse,
            rewriter.create<LLVM::ConstantOp>(
                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}
namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();

    // The result of ldmatrix is a struct of 32-bit integer registers when more
    // than one 32-bit value is returned; otherwise it is a single i32. The
    // result of the nvgpu op is a vector of shape (numRegisters, register
    // width), where each register is 32 bits wide.
    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
    Value srcPtr =
        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
        loc, ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // Unpack the 32-bit registers, bitcast each back to the element vector
    // type, and repack them into the converted result type.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? rewriter.create<LLVM::ExtractValueOp>(
                                 loc, rewriter.getI32Type(), ldMatrixResult,
                                 rewriter.getI64ArrayAttr(i))
                           : ldMatrixResult;
      Value casted =
          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
      result = rewriter.create<LLVM::InsertValueOp>(
          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // The operand element types and the mmaShape attribute select which
    // nvvm.mma.sync intrinsic variant this op is lowered to.
    auto aType = op.getMatrixA().getType().cast<VectorType>();
    auto cType = op.getMatrixC().getType().cast<VectorType>();

    int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
    int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
    int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
    std::array<int64_t, 3> gemmShape{m, n, k};

    NVVM::MMATypes ptxTypeA;
    NVVM::MMATypes ptxTypeB;
    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
        cType.getElementType(), /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Select the PTX operand types; the integer variants use saturating
    // (satfinite) overflow behavior.
    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
    if (aType.getElementType().isInteger(8)) {
      ptxTypeA = NVVM::MMATypes::s8;
      ptxTypeB = NVVM::MMATypes::s8;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isInteger(4)) {
      ptxTypeA = NVVM::MMATypes::s4;
      ptxTypeB = NVVM::MMATypes::s4;
      overflow = NVVM::MMAIntOverflow::satfinite;
    } else if (aType.getElementType().isF16()) {
      ptxTypeA = NVVM::MMATypes::f16;
      ptxTypeB = NVVM::MMATypes::f16;
    } else if (aType.getElementType().isF64()) {
      ptxTypeA = NVVM::MMATypes::f64;
      ptxTypeB = NVVM::MMATypes::f64;
    } else if (aType.getElementType().isF32()) {
      ptxTypeA = NVVM::MMATypes::tf32;
      ptxTypeB = NVVM::MMATypes::tf32;
    } else {
      return op->emitError("could not deduce operand PTX types");
    }

    SmallVector<Value> matA =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
        op.getLoc(), intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/llvm::None,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
struct ConvertNVGPUToNVVMPass
    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
  ConvertNVGPUToNVVMPass() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    // Device-side async tokens cannot be materialized in NVVM. Convert them
    // to a dummy i32 type so they can easily be dropped during conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                        adaptor.getDstIndices(), rewriter);
    auto i8Ty = IntegerType::get(op.getContext(), 8);
    auto dstPointerType =
        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);

    auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    auto srcPointerType =
        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
    scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
    // The intrinsic takes a global pointer, so an address space cast is
    // required on the source pointer.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                    scrPtr);
    int64_t numElements = adaptor.getNumElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
    // Bypassing L1 is only supported for 16-byte copies; drop the hint
    // otherwise.
    UnitAttr bypassL1 =
        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
    rewriter.create<NVVM::CpAsyncOp>(
        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);

    // Drop the result token: replace it with a dummy i32 constant.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token: replace it with a dummy i32 constant.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroups is not present, pick 0 as a conservatively correct value.
    int32_t numGroups = adaptor.getNumGroups() ? *adaptor.getNumGroups() : 0;
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                 RewritePatternSet &patterns) {
  patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
      converter);
}

std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
  return std::make_unique<ConvertNVGPUToNVVMPass>();
}
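
// -----------------------------------------------------------------------------
// Editorial usage sketch (not part of the upstream file): one plausible way to
// schedule this conversion from C++. Only createConvertNVGPUToNVVMPass() above
// is defined by this file; the helper name, the gpu.module nesting, and the
// extra includes below are illustrative assumptions.
// -----------------------------------------------------------------------------
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Pass/PassManager.h"

static void addNVGPUToNVVMLoweringPasses(mlir::PassManager &pm) {
  // Run the conversion on each gpu.module so that nvgpu ops (mma.sync,
  // ldmatrix, device_async_*) are rewritten into NVVM intrinsics.
  pm.addNestedPass<mlir::gpu::GPUModuleOp>(
      mlir::createConvertNVGPUToNVVMPass());
}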