19 using namespace mlir::amx;
30 unsigned width = vType.getElementType().getIntOrFloatBitWidth();
31 assert(llvm::isPowerOf2_64(width) && width >= 8);
32 unsigned bytes = width >> 3;
35 return std::make_pair(
36 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
37 rewriter.
create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
42 if (mType.getRank() < 2)
44 int64_t last = mType.getRank() - 1;
57 assert(mType.getRank() >= 2);
58 int64_t last = mType.getRank() - 1;
60 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
61 assert(llvm::isPowerOf2_64(width) && width >= 8);
62 unsigned bytes = width >> 3;
63 if (mType.isDynamicDim(last)) {
67 Value scale = rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68 return rewriter.
create<LLVM::MulOp>(
69 loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
73 return rewriter.
create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
79 matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
81 VectorType vType = op.getVectorType();
83 std::pair<Value, Value> tsz =
84 getTileSizes(rewriter, *getTypeConverter(), vType, op.
getLoc());
97 matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
99 MemRefType mType = op.getMemRefType();
100 VectorType vType = op.getVectorType();
102 std::pair<Value, Value> tsz =
103 getTileSizes(rewriter, *getTypeConverter(), vType, op.
getLoc());
105 if (
failed(verifyStride(mType)))
107 Value stride = getStride(rewriter, *getTypeConverter(), mType,
108 adaptor.getBase(), op.
getLoc());
110 Value ptr = getStridedElementPtr(op.
getLoc(), mType, adaptor.getBase(),
111 adaptor.getIndices(), rewriter);
114 op, resType, tsz.first, tsz.second, ptr, stride);
123 matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
125 MemRefType mType = op.getMemRefType();
126 VectorType vType = op.getVectorType();
128 std::pair<Value, Value> tsz =
129 getTileSizes(rewriter, *getTypeConverter(), vType, op.
getLoc());
131 if (
failed(verifyStride(mType)))
133 Value stride = getStride(rewriter, *getTypeConverter(), mType,
134 adaptor.getBase(), op.
getLoc());
136 Value ptr = getStridedElementPtr(op.
getLoc(), mType, adaptor.getBase(),
137 adaptor.getIndices(), rewriter);
139 op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
147 matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
149 VectorType aType = op.getLhsVectorType();
150 VectorType bType = op.getRhsVectorType();
151 VectorType cType = op.getVectorType();
153 std::pair<Value, Value> tsza =
154 getTileSizes(rewriter, *getTypeConverter(), aType, op.
getLoc());
155 std::pair<Value, Value> tszb =
156 getTileSizes(rewriter, *getTypeConverter(), bType, op.
getLoc());
160 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
161 adaptor.getLhs(), adaptor.getRhs());
169 matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
171 VectorType aType = op.getLhsVectorType();
172 VectorType bType = op.getRhsVectorType();
173 VectorType cType = op.getVectorType();
175 std::pair<Value, Value> tsza =
176 getTileSizes(rewriter, *getTypeConverter(), aType, op.
getLoc());
177 std::pair<Value, Value> tszb =
178 getTileSizes(rewriter, *getTypeConverter(), bType, op.
getLoc());
181 bool zexta = op.getIsZextLhs();
182 bool zextb = op.getIsZextRhs();
185 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
186 adaptor.getLhs(), adaptor.getRhs());
187 else if (zexta && !zextb)
189 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
190 adaptor.getLhs(), adaptor.getRhs());
191 else if (!zexta && zextb)
193 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
194 adaptor.getLhs(), adaptor.getRhs());
197 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
198 adaptor.getLhs(), adaptor.getRhs());
207 patterns.
add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
208 TileMulFConversion, TileMulIConversion>(converter);
212 target.
addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
213 x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
214 x86_amx_tdpbusd, x86_amx_tdpbuud>();
215 target.
addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
MLIRContext & getContext() const
Returns the MLIR context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing `op`).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.