20 using namespace mlir::amx;
31 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
32 assert(llvm::isPowerOf2_64(width) && width >= 8);
33 unsigned bytes = width >> 3;
36 return std::make_pair(
37 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
38 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
47 if (mType.getRank() < 2)
49 int64_t preLast = mType.getRank() - 2;
51 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
52 assert(llvm::isPowerOf2_64(width) && width >= 8);
53 unsigned bytes = width >> 3;
56 if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
58 if (strides[preLast] == ShapedType::kDynamic) {
62 Value scale = rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
64 .
create<LLVM::MulOp>(loc, llvmInt64Type, scale,
65 memrefDescriptor.stride(rewriter, loc, preLast))
70 return rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
77 matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
79 amx::TileType tType = op.getTileType();
81 std::pair<Value, Value> tsz =
82 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
95 matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
97 MemRefType mType = op.getMemRefType();
98 amx::TileType tType = op.getTileType();
100 std::pair<Value, Value> tsz =
101 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
103 auto stride = getStride(rewriter, *getTypeConverter(), mType,
104 adaptor.getBase(), op.getLoc());
108 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
109 adaptor.getIndices(), rewriter);
112 op, resType, tsz.first, tsz.second, ptr, stride.value());
121 matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
123 MemRefType mType = op.getMemRefType();
124 amx::TileType tType = op.getTileType();
126 std::pair<Value, Value> tsz =
127 getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
129 auto stride = getStride(rewriter, *getTypeConverter(), mType,
130 adaptor.getBase(), op.getLoc());
134 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
135 adaptor.getIndices(), rewriter);
137 op, tsz.first, tsz.second, ptr, stride.value(), adaptor.getVal());
145 matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
147 amx::TileType aType = op.getLhsTileType();
148 amx::TileType bType = op.getRhsTileType();
149 amx::TileType cType = op.getTileType();
151 std::pair<Value, Value> tsza =
152 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
153 std::pair<Value, Value> tszb =
154 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
157 if (aType.getElementType().isBF16())
159 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
160 adaptor.getLhs(), adaptor.getRhs());
161 else if (aType.getElementType().isF16())
163 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
164 adaptor.getLhs(), adaptor.getRhs());
166 llvm_unreachable(
"Unexpected element type for amx.mulf");
174 matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
176 amx::TileType aType = op.getLhsTileType();
177 amx::TileType bType = op.getRhsTileType();
178 amx::TileType cType = op.getTileType();
180 std::pair<Value, Value> tsza =
181 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
182 std::pair<Value, Value> tszb =
183 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
186 bool zexta = op.getIsZextLhs();
187 bool zextb = op.getIsZextRhs();
190 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
191 adaptor.getLhs(), adaptor.getRhs());
192 else if (zexta && !zextb)
194 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
195 adaptor.getLhs(), adaptor.getRhs());
196 else if (!zexta && zextb)
198 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
199 adaptor.getLhs(), adaptor.getRhs());
202 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
203 adaptor.getLhs(), adaptor.getRhs());
212 patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
213 TileMulFConversion, TileMulIConversion>(converter);
220 target.
addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
221 x86_amx_tdpbf16ps, x86_amx_tdpfp16ps, x86_amx_tdpbssd,
222 x86_amx_tdpbsud, x86_amx_tdpbusd, x86_amx_tdpbuud>();
223 target.
addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
232 void populateConvertToLLVMConversionPatterns(
242 dialect->addInterfaces<AMXToLLVMDialectInterface>();
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry)
Register LLVM conversion interface for AMX dialect.