#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
// ...
        precision(precision) {}
  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
    // ...
    if (op->hasAttr(op.getTf32EnabledAttrName()) ||
        !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
    // ...
      return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
                                 "unknown precision level");
    // ...
      return emitError(location, "TF32x3 is not supported at the moment "
                                 "for nvgpu.mma.sync on f32 datatype");
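The guard and the two diagnostics above form a three-way decision: skip the op, reject the requested precision, or rewrite. The following standalone C++ sketch models that control flow without depending on MLIR. `F32Lowering`, `MmaSyncModel`, `Outcome` and `rewriteModel` are hypothetical names. The sketch also assumes that the elided branch after the guard reports a match failure, which it models as `Outcome::NoMatch`.

```cpp
#include <string>

// Hypothetical mirror of the lowering option (MLIR's real type is
// nvgpu::MmaSyncF32Lowering; these enumerator names are illustrative).
enum class F32Lowering { TF32, TF32x3, Unknown };

enum class ElemType { F16, F32 };

// Hypothetical stand-in for the two facts the pattern inspects on nvgpu.mma.sync.
struct MmaSyncModel {
  ElemType matrixAElem;  // element type of the A operand's vector type
  bool tf32Enabled;      // whether the tf32Enabled unit attribute is present
};

enum class Outcome { NoMatch, Error, Rewritten };

// Models the decision order in matchAndRewrite: guard, then precision checks,
// then the in-place attribute update.
Outcome rewriteModel(MmaSyncModel &op, F32Lowering precision,
                     std::string &diag) {
  // Already tagged, or A is not f32: the pattern does not apply.
  if (op.tf32Enabled || op.matrixAElem != ElemType::F32)
    return Outcome::NoMatch;
  if (precision == F32Lowering::Unknown) {
    diag = "MmaSync F32-to-TF32 cannot be lowered with unknown precision level";
    return Outcome::Error;
  }
  if (precision == F32Lowering::TF32x3) {
    diag = "TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype";
    return Outcome::Error;
  }
  // TF32: tag the op; nothing else in the model changes.
  op.tf32Enabled = true;
  return Outcome::Rewritten;
}
```

In this model, an op rejected with an error is left untagged. An op that was already tagged never reaches the precision checks.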
    // ...
          op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
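With the TF32 option, the rewrite only attaches the `tf32Enabled` unit attribute. The operands remain f32, and the attribute lets later lowering treat them as TF32 (8 exponent bits, 10 explicit mantissa bits). The sketch below illustrates the precision this trades away. It also shows how a TF32x3-style scheme, which the pattern rejects above, can recover much of that precision by splitting each operand into a "big" and a "small" TF32 part and dropping the small*small term. `toTf32`, `mul1xTf32` and `mul3xTf32` are hypothetical helpers. The conversion truncates for simplicity, whereas hardware conversion may round. The source does not say how a future TF32x3 lowering would actually be implemented.

```cpp
#include <cstdint>
#include <cstring>

// Keep sign, the 8 exponent bits, and the top 10 of f32's 23 mantissa bits
// by clearing the low 13 bits (truncation, not rounding). NaN payloads that
// live only in the low bits are not handled and would become infinity.
float toTf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFFE000u;
  float y;
  std::memcpy(&y, &bits, sizeof y);
  return y;
}

// A single TF32 product: both operands lose their low mantissa bits.
float mul1xTf32(float a, float b) { return toTf32(a) * toTf32(b); }

// TF32x3-style product: a ~ aBig + aSmall, b ~ bBig + bSmall, each part
// TF32-representable; the tiny aSmall*bSmall term is dropped.
float mul3xTf32(float a, float b) {
  float aBig = toTf32(a), aSmall = toTf32(a - aBig);
  float bBig = toTf32(b), bSmall = toTf32(b - bBig);
  // Sum the two cross terms first, then add the dominant big*big term.
  return aBig * bBig + (aBig * bSmall + aSmall * bBig);
}
```

With `a = 1 + 2^-11` and `b = 1`, the single TF32 product returns `1`, while the three-term product returns `a` exactly.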
  // ...
  patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
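`populateMmaSyncF32ToTF32Patterns` only registers the pattern, and a rewrite driver then applies the registered patterns repeatedly. The early return on `hasAttr(op.getTf32EnabledAttrName())` makes repeated application safe, because once an op is tagged the pattern no longer matches. Below is a toy model of that interaction. `OpModel`, `PatternFn`, `populateModelPatterns` and `applyUntilFixpoint` are hypothetical stand-ins for MLIR's types and its greedy driver, not the MLIR API.

```cpp
#include <functional>
#include <vector>

// Hypothetical op: whether its A operand is f32, and whether it is tagged.
struct OpModel {
  bool isF32;
  bool tf32Enabled;
};

// A pattern returns true when it changed the op (a successful rewrite).
using PatternFn = std::function<bool(OpModel &)>;

// Loose analogue of populateMmaSyncF32ToTF32Patterns: register one pattern.
void populateModelPatterns(std::vector<PatternFn> &patterns) {
  patterns.push_back([](OpModel &op) {
    if (op.tf32Enabled || !op.isF32)
      return false;         // guard: already tagged or not f32
    op.tf32Enabled = true;  // in-place attribute update
    return true;
  });
}

// Sweep all patterns over all ops until a sweep changes nothing (or the
// sweep budget runs out); returns the number of sweeps performed.
int applyUntilFixpoint(std::vector<OpModel> &ops,
                       const std::vector<PatternFn> &patterns,
                       int maxSweeps = 10) {
  int sweeps = 0;
  bool changed = true;
  while (changed && sweeps < maxSweeps) {
    changed = false;
    ++sweeps;
    for (OpModel &op : ops)
      for (const PatternFn &p : patterns)
        changed |= p(op);
  }
  return sweeps;
}
```

Without the `tf32Enabled` check in the guard, every f32 op would be "rewritten" on each sweep, and the loop would stop only when it hit `maxSweeps`.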