20 using namespace mlir::amx;
31 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
32 assert(llvm::isPowerOf2_64(width) && width >= 8);
33 unsigned bytes = width >> 3;
36 return std::make_pair(
37 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
38 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
47 if (mType.getRank() < 2)
49 int64_t preLast = mType.getRank() - 2;
51 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
52 assert(llvm::isPowerOf2_64(width) && width >= 8);
53 unsigned bytes = width >> 3;
59 if (strides[preLast] == ShapedType::kDynamic) {
63 Value scale = rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
65 .
create<LLVM::MulOp>(loc, llvmInt64Type, scale,
66 memrefDescriptor.stride(rewriter, loc, preLast))
71 return rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
78 matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
80 amx::TileType tType = op.getTileType();
82 std::pair<Value, Value> tsz =
83 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
96 matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
98 MemRefType mType = op.getMemRefType();
99 amx::TileType tType = op.getTileType();
101 std::pair<Value, Value> tsz =
102 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
104 auto stride = getStride(rewriter, *getTypeConverter(), mType,
105 adaptor.getBase(), op.getLoc());
109 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
110 adaptor.getIndices(), rewriter);
113 op, resType, tsz.first, tsz.second, ptr, stride.value());
122 matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
124 MemRefType mType = op.getMemRefType();
125 amx::TileType tType = op.getTileType();
127 std::pair<Value, Value> tsz =
128 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
130 auto stride = getStride(rewriter, *getTypeConverter(), mType,
131 adaptor.getBase(), op.getLoc());
135 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
136 adaptor.getIndices(), rewriter);
138 op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
146 matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
148 amx::TileType aType = op.getLhsTileType();
149 amx::TileType bType = op.getRhsTileType();
150 amx::TileType cType = op.getTileType();
152 std::pair<Value, Value> tsza =
153 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
154 std::pair<Value, Value> tszb =
155 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
158 if (aType.getElementType().isBF16())
160 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
161 adaptor.getLhs(), adaptor.getRhs());
162 else if (aType.getElementType().isF16())
164 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
165 adaptor.getLhs(), adaptor.getRhs());
167 llvm_unreachable(
"Unexpected element type for amx.mulf");
175 matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
177 amx::TileType aType = op.getLhsTileType();
178 amx::TileType bType = op.getRhsTileType();
179 amx::TileType cType = op.getTileType();
181 std::pair<Value, Value> tsza =
182 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
183 std::pair<Value, Value> tszb =
184 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
187 bool zexta = op.getIsZextLhs();
188 bool zextb = op.getIsZextRhs();
191 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
192 adaptor.getLhs(), adaptor.getRhs());
193 else if (zexta && !zextb)
195 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
196 adaptor.getLhs(), adaptor.getRhs());
197 else if (!zexta && zextb)
199 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
200 adaptor.getLhs(), adaptor.getRhs());
203 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
204 adaptor.getLhs(), adaptor.getRhs());
213 patterns.
add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
214 TileMulFConversion, TileMulIConversion>(converter);
221 target.
addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
222 x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
223 x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
224 target.
addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
233 void populateConvertToLLVMConversionPatterns(
243 dialect->addInterfaces<AMXToLLVMDialectInterface>();
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry)
Register LLVM conversion interface for AMX dialect.