21 #include "llvm/ADT/STLExtras.h"
25 #define GEN_PASS_DEF_XEGPUBLOCKING
26 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "xegpu-blocking"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
44 resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
48 auto hasIdenticalVectorTypes = [](
ValueRange values) {
50 return llvm::all_of(types, [&](
Type type) {
51 return isa<VectorType>(type) && type == types.front();
57 if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
58 LDBG(
"skip unrealized conversion cast op not emulating pack/unpack.");
62 VectorType outputTy = dyn_cast<VectorType>(outputs[0].
getType());
64 if (inputs.size() > 1 && outputs.size() == 1) {
68 builder, castOp.getLoc(), inputs, shape);
69 castOp->replaceAllUsesWith(
ValueRange(result));
71 }
else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
75 builder, castOp.getLoc(), inputs[0], tileShape);
76 castOp->replaceAllUsesWith(results);
89 class XeGPUBlockingPass final
90 :
public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
92 void runOnOperation()
override;
99 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
100 std::is_same_v<T, OpResult>>>
101 std::optional<SmallVector<int64_t>>
114 template <
typename T,
typename>
115 std::optional<SmallVector<int64_t>>
118 if constexpr (std::is_same_v<T, OpOperand>)
119 value = operandOrResult.get();
121 value = (
Value)operandOrResult;
124 if (layout && layout.isSgLayout()) {
125 if (
auto inst_data = layout.getInstData())
126 return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
128 if (
auto type = dyn_cast<ShapedType>(value.
getType()))
129 return llvm::to_vector(type.getShape());
131 LDBG(
"failed to getTileShape for: " << value);
135 std::optional<SmallVector<int64_t>>
137 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
138 xegpu::UpdateOffsetOp>(op))
140 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
141 xegpu::LoadGatherOp>(op))
143 if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
146 if (isa<xegpu::DpasOp>(op)) {
147 std::optional<SmallVector<int64_t>> aTile =
149 std::optional<SmallVector<int64_t>> bTile =
152 if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
156 if ((*aTile)[1] != (*bTile)[0])
161 std::optional<SmallVector<int64_t>> cTile =
163 int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
164 if (!cTile || !llvm::equal(*cTile, expectedCTile))
174 if (isa<vector::MultiDimReductionOp>(op))
177 if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
183 bool XeGPUBlockingPass::needsUnroll(
Operation *op)
const {
185 bool hasWgLayoutOperands =
187 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
188 return layout && layout.isWgLayout();
190 bool hasWgLayoutResults =
192 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
193 return layout && layout.isWgLayout();
195 if (hasWgLayoutOperands || hasWgLayoutResults) {
196 LDBG(
"skip unrolling for op with workgroup level layout: " << *op);
202 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
203 xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
204 return layout && layout.getInstData();
206 auto shapedType = dyn_cast<ShapedType>(valTy);
207 return shapedType && !llvm::equal(tileShape, shapedType.getShape());
210 bool hasUnrollableOperands =
212 std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
213 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
215 bool hasUnrollableResults =
217 std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
218 return tileShape.has_value() && isUnrollable(result, *tileShape);
220 return hasUnrollableOperands || hasUnrollableResults;
223 void XeGPUBlockingPass::runOnOperation() {
233 xegpu::LayoutAttr layout) {
236 if (layout && layout.getInstData()) {
238 tileShape = llvm::to_vector_of<int64_t>(instData.
asArrayRef());
241 return std::make_pair(tileShape, count);
248 [&](RankedTensorType type,
250 Type elemTy = type.getElementType();
254 llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
255 if (layout && layout.isWgLayout())
260 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
262 result.append(count, newTy);
266 [&](xegpu::TensorDescType type,
268 Type elemTy = type.getElementType();
271 xegpu::LayoutAttr layout = type.getLayoutAttr();
272 if (layout && layout.isWgLayout())
277 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
280 layout = layout.dropInstData();
283 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
284 result.append(count, newTy);
292 [&](
Operation *op) -> LogicalResult {
return success(needsUnroll(op)); });
297 Type elemTy = type.getElementType();
300 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
302 Attribute encoding = tdescTy.getEncoding();
305 if (tdescTy.isScattered()) {
307 llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
308 int64_t chunkSize = scatterAttr.getChunkSize().getInt();
311 int64_t blockedChunkSize = chunkSize;
312 auto instData = tdescTy.getLayoutAttr().getInstData();
313 if (!instData.empty())
314 blockedChunkSize = instData.asArrayRef().back();
318 ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
320 encoding = newEncoding;
326 tdescTy.getLayoutAttr().dropInstData());
328 newTy = type.clone(tileShape, elemTy);
331 std::optional<SmallVector<int64_t>> ratio =
333 assert(ratio &&
"The shape of the type must be a multiple of tileShape.");
343 vector::populateVectorUnrollPatterns(
patterns, vectorOptions);
358 if (
auto layout = op->
getAttrOfType<xegpu::LayoutAttr>(name)) {
360 if (!isa<LoopLikeOpInterface>(op))
366 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
367 resolveUnrealizedConversionCastOp(castOp);
static MLIRContext * getContext(OpFoldResult val)
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
OpOperand & getOpOperand(unsigned idx)
bool hasAttrOfType(NameT &&name)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
unsigned getNumResults()
Return the number of results held by this operation.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
LayoutAttr getLayoutAttr(const Value value)
Retrieves the LayoutAttr associated with a given Value.
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach LayoutAttr.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
void setLayoutAttrs(Operation *op, function_ref< LayoutAttr(Value)> getLayoutImpl)
Set the LayoutAttr for each OpOperand and OpResult of the given operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
Options to control the XeGPU unrolling.