#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"

using namespace mlir;
// Resolve the unrealized conversion cast ops generated by the SCF structural
// type conversion. Only two forms are handled: an N:1 vector cast (lowered
// with vector.insert_strided_slice) and a 1:N vector cast (lowered with
// vector.extract_strided_slice).
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };

  // Only casts whose inputs and outputs all share one vector type are
  // emulating pack/unpack.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // N:1 cast: assemble the tiles into one vector of the output shape.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // 1:N cast: extract the tiles back out of the packed vector.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}
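// Illustrative sketch (shapes invented for exposition): an N:1 cast such as
//   %r = builtin.unrealized_conversion_cast %a, %b
//        : vector<8x16xf32>, vector<8x16xf32> to vector<16x16xf32>
// is replaced by vector.insert_strided_slice ops that write %a and %b into a
// vector<16x16xf32>, while the mirror-image 1:N cast is replaced by
// vector.extract_strided_slice ops carving the 8x16 tiles back out.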
// Rewrite a ConvertLayoutOp whose layouts carry inst_data into one whose
// layouts have inst_data dropped, since blocking materializes the
// instruction-level tiling separately.
struct ConvertLayoutOpPattern
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
        targetLayout.getEffectiveInstDataAsInt().empty())
      return rewriter.notifyMatchFailure(op, "not a target ConvertLayoutOp.");

    inputLayout = inputLayout.dropInstData();
    targetLayout = targetLayout.dropInstData();
    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
        op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
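// Illustrative sketch (attribute syntax abbreviated): a convert_layout
// between two inst_data-carrying layouts, e.g.
//   %1 = xegpu.convert_layout %0 <{input_layout = #xegpu.layout<inst_data = [8, 16], ...>,
//                                  target_layout = #xegpu.layout<inst_data = [16, 16], ...>}>
//        : vector<32x32xf32>
// is re-created with inst_data stripped from both layouts; the unroll
// patterns registered below handle the instruction-level tiling instead.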
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the instruction-level tile shape of an OpOperand or OpResult from
  // its layout attribute; std::nullopt if none is available.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape used to unroll the given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Return true if any operand or result of the op must be unrolled.
  bool needsUnroll(Operation *op) const;
};
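// The overall effect of the pass, as an illustrative sketch (types
// abbreviated): given inst_data = [8, 16] on a 32x32 tile,
//   %t = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf16>
//        -> !xegpu.tensor_desc<32x32xf16, #xegpu.layout<inst_data = [8, 16]>>
//   %v = xegpu.load_nd %t : !xegpu.tensor_desc<32x32xf16, ...> -> vector<32x32xf16>
// is unrolled into eight create_nd_tdesc/load_nd pairs producing
// vector<8x16xf16> values, and the layouts lose their inst_data field.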
template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::DistributeLayoutAttr layout =
      xegpu::getDistributeLayoutAttr(operandOrResult);
  if (layout && layout.isForSubgroup()) {
    if (!layout.getEffectiveInstDataAsInt().empty())
      return layout.getEffectiveInstDataAsInt();

    // Without inst_data the value is consumed as a whole.
    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG() << "failed to getTileShape for: " << value;
  return std::nullopt;
}
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
          xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
          xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));
    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;
    // The K dimensions of the A [M, K] and B [K, N] tiles must match.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;
    // If an accumulator is present, its tile must be [M, N].
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }
    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));
  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));
  return std::nullopt;
}
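// For xegpu.dpas the tile is the [M, K, N] triple; e.g. (illustrative)
// aTile = [8, 16] and bTile = [16, 16] pass the K check (16 == 16), and with
// cTile = [8, 16] == {aTile[0], bTile[1]} the result is [8, 16, 16].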
bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  // Skip ops whose operands or results carry workgroup-level layouts; they
  // are handled by workgroup distribution, not by this pass.
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(opr);
        return layout && layout.isForWorkgroup();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(result);
        return layout && layout.isForWorkgroup();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
    return false;
  }

  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && !layout.getEffectiveInstDataAsInt().empty();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}
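// Example (illustrative): a vector<32x32xf16> operand with tile shape
// [8, 16] is unrollable because [8, 16] != [32, 32]; a tensor_desc operand
// is unrollable whenever its layout still carries non-empty inst_data.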
void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Cache every operand/result layout on its owner op so the attribute
  // survives when the defining op is replaced during rewriting.
  xegpu::setDistributeLayoutAttrs(
      op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });

  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };
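  // Worked example (illustrative): shape = [32, 32] with
  // inst_data = [8, 16] yields tileShape = [8, 16] and
  // count = (32 * 32) / (8 * 16) = 8 tile-typed values.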
  // Perform type conversion for SCF structural ops: a type with a
  // subgroup-level layout is split into `count` copies of its tile type.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        if (layout)
          layout = layout.dropInstData();
        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
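  // Illustrative sketch of the structural conversion: an scf.for iter_arg of
  // type vector<32x32xf16> with inst_data = [8, 16] is carried as 8 values
  // of vector<8x16xf16>; the unrealized_conversion_cast ops inserted at the
  // region boundaries are cleaned up by resolveUnrealizedConversionCastOp at
  // the end of this pass.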
  xegpu::UnrollOptions options;
  options.setFilterConstraint(
      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
  options.setNativeShapeFn(
      [&](Operation *op) { return getTileShape(op); });
  options.setUnrolledTypesFn(
      [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
        Type elemTy = type.getElementType();
        Type newTy;

        if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
          Attribute encoding = tdescTy.getEncoding();
          // A scattered tensor_desc may need its chunk_size reduced to the
          // innermost inst_data dimension.
          if (tdescTy.isScattered()) {
            int64_t chunkSize = tdescTy.getChunkSizeAsInt();
            if (chunkSize > 1) {
              int64_t blockedChunkSize = chunkSize;
              auto instData = tdescTy.getLayoutAttr().getInstData();
              if (!instData.empty())
                blockedChunkSize = instData.asArrayRef().back();
              encoding = xegpu::ScatterTensorDescAttr::get(
                  ctx, tdescTy.getMemorySpace(), blockedChunkSize);
            }
          }
          newTy = xegpu::TensorDescType::get(
              ctx, tileShape, elemTy, encoding,
              tdescTy.getLayoutAttr().dropInstData());
        } else {
          newTy = type.clone(tileShape, elemTy);
        }

        std::optional<SmallVector<int64_t>> ratio =
            computeShapeRatio(type.getShape(), tileShape);
        assert(ratio && "The shape of the type must be a multiple of tileShape.");
        return SmallVector<Type>(computeProduct(*ratio), newTy);
      });
  RewritePatternSet patterns(ctx);
  patterns.add<ConvertLayoutOpPattern>(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(op, std::move(patterns));
  // Drop the cached layout attributes; outside loop-like ops, reattach the
  // result layouts with inst_data removed.
  op->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
      }
    }
    // Lower the pack/unpack casts left by the SCF type conversion.
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}
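// Usage sketch (assuming the registered pass name matches DEBUG_TYPE):
//   mlir-opt -xegpu-blocking input.mlir
// After the pass, ops whose layouts carried inst_data operate directly on
// the instruction-level tile shapes and the inst_data fields are gone.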