#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"

#define DEBUG_TYPE "xegpu-blocking"
// Resolve an unrealized_conversion_cast op that emulates pack/unpack between
// N identical small vectors and one large vector: the N:1 form is lowered
// with vector.insert_strided_slice, the 1:N form with
// vector.extract_strided_slice.
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };

  // Only casts where all inputs and all outputs share one vector type are
  // emulating pack/unpack; leave anything else untouched.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // N:1 cast (unpack): assemble one large vector from the pieces.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // 1:N cast (pack): slice the large vector into tile-shaped pieces.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}
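// Illustrative sketch (shapes invented for exposition, not taken from a
// test): the SCF type conversion can leave behind a 1:N cast such as
//
//   %0:4 = builtin.unrealized_conversion_cast %v
//       : vector<16x32xf16> to vector<8x16xf16>, vector<8x16xf16>,
//         vector<8x16xf16>, vector<8x16xf16>
//
// which this helper replaces with four vector.extract_strided_slice ops on
// %v; the mirrored 4:1 form is rebuilt with vector.insert_strided_slice.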
// Rewrite a convert_layout op whose layouts both carry inst_data by dropping
// the inst_data fields; that level of the conversion is resolved by the
// unroll patterns instead.
struct ConvertLayoutOpPattern
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
        targetLayout.getEffectiveInstDataAsInt().empty())
      return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");

    inputLayout = inputLayout.dropInstData();
    targetLayout = targetLayout.dropInstData();
    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
        op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
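// For intuition, a schematic example (layout values are assumptions, not
// from a test): a convert_layout whose two layouts differ at the inst_data
// level, e.g.
//
//   input_layout  = #xegpu.layout<inst_data = [8, 16],  lane_layout = [1, 16]>
//   target_layout = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16]>
//
// is recreated with both inst_data fields dropped; the inst-level data
// movement itself is materialized later by the unroll patterns.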
/// XeGPUBlockingPass unrolls XeGPU and vector ops working on large shapes
/// into ops working on the smaller tile shapes given by the layouts'
/// inst_data fields.
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Tile shape of an OpOperand or OpResult, derived from its layout
  // attribute; std::nullopt if no usable subgroup-level layout is found.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Tile shape for a whole operation, taken from the operand or result that
  // anchors its layout.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Whether any operand or result of `op` must be unrolled to its tile shape.
  bool needsUnroll(Operation *op) const;
};
template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::DistributeLayoutAttr layout =
      xegpu::getDistributeLayoutAttr(operandOrResult);
  if (layout && layout.isForSubgroup()) {
    if (!layout.getEffectiveInstDataAsInt().empty()) {
      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
      // Drop leading unit dims unless the producer is an op whose unroll
      // pattern expects the full-rank inst_data.
      Operation *definingOp = value.getDefiningOp();
      bool skipLeadingUnitDimRemoval =
          definingOp &&
          isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
              xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp);
      if (!skipLeadingUnitDimRemoval) {
        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
        instData.erase(instData.begin(), it);
      }
      return instData;
    }

    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG() << "failed to getTileShape for: " << value;
  return std::nullopt;
}
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
          xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
          xegpu::StoreMatrixOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
    if (loadGatherOp.getOffsets())
      return getTileShape(loadGatherOp->getOpResult(0));
    return getTileShape(loadGatherOp->getOpOperand(0));
  }
  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
    return getTileShape(storeScatterOp.getOffsets()
                            ? storeScatterOp->getOpOperand(0)
                            : storeScatterOp->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));
    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;
    // A tile is MxK and B tile is KxN; the K dims must match.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;
    // If an accumulator is present, its tile must be MxN.
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }
    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));
  return std::nullopt;
}
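// Worked example (tile values are made up): for a dpas op with A tile
// [8, 16] and B tile [16, 16], the shared K dim (16) matches, the C tile,
// if present, must equal [8, 16], and the returned op-level native shape is
// [8, 16, 16], i.e. (M, K, N).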
bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  // Skip ops whose operands or results still carry workgroup-level layouts;
  // those must be distributed to subgroups before blocking applies.
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(opr);
        return layout && layout.isForWorkgroup();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(result);
        return layout && layout.isForWorkgroup();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
    return false;
  }

  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && !layout.getEffectiveInstDataAsInt().empty();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}
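// Example (assumed shapes): an operand of type vector<32x32xf16> whose
// cached layout carries inst_data = [8, 16] reports a tile shape different
// from its type, so needsUnroll returns true; a value already of type
// vector<8x16xf16> with the same inst_data does not trigger unrolling.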
void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Cache each value's layout on its owner op so the layout survives
  // pattern rewrites that replace the defining op.
  xegpu::setDistributeLayoutAttrs(
      op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });

  auto getTileShapeAndCount = [](ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };

  // Type conversion for SCF structural ops: one blocked type becomes
  // `count` copies of its tile type.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();
        xegpu::LayoutAttr layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isForWorkgroup())
          return failure();
        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isForWorkgroup())
          return failure();
        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        if (layout)
          layout = layout.dropInstData();
        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);

  // Cast away leading unit dims on vectors before unrolling.
  {
    RewritePatternSet patterns(ctx);
    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsGreedily(op, std::move(patterns));
  }

  xegpu::UnrollOptions options;
  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
    return success(needsUnroll(op));
  });
  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
  options.setUnrolledTypesFn(
      [&](ShapedType type, ArrayRef<int64_t> tileShape,
          bool returnSingleType = false) -> SmallVector<Type> {
        Type elemTy = type.getElementType();
        Type newTy;
        if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
          Attribute encoding = tdescTy.getEncoding();
          // A scattered tensor_desc may need its chunk size reduced to the
          // innermost inst_data dim.
          if (tdescTy.isScattered()) {
            int64_t chunkSize = tdescTy.getChunkSizeAsInt();
            if (chunkSize > 1) {
              int64_t blockedChunkSize = chunkSize;
              auto instData = tdescTy.getLayoutAttr().getInstData();
              if (!instData.empty())
                blockedChunkSize = instData.asArrayRef().back();
              encoding = xegpu::ScatterTensorDescAttr::get(
                  ctx, tdescTy.getMemorySpace(), blockedChunkSize);
            }
          }
          newTy = xegpu::TensorDescType::get(
              ctx, tileShape, elemTy, encoding,
              tdescTy.getLayoutAttr().dropInstData());
        } else {
          newTy = type.clone(tileShape, elemTy);
        }
        if (returnSingleType)
          return SmallVector<Type>{newTy};
        std::optional<SmallVector<int64_t>> ratio =
            computeShapeRatio(type.getShape(), tileShape);
        assert(ratio && "The shape of the type must be a multiple of tileShape.");
        return SmallVector<Type>(computeProduct(*ratio), newTy);
      });

  RewritePatternSet patterns(ctx);
  patterns.add<ConvertLayoutOpPattern>(ctx);
  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);
  (void)applyPatternsGreedily(op, std::move(patterns));

  op->walk([](Operation *op) {
    // Drop the cached per-operand layout attributes.
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }
    // Drop the cached per-result attributes and re-attach inst_data-free
    // layouts to results of non-loop ops.
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
      }
    }
    // Lower the pack/unpack casts introduced by the SCF type conversion.
    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}
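// End-to-end sketch (hypothetical IR; the pass flag is assumed to be
// "xegpu-blocking", matching GEN_PASS_DEF_XEGPUBLOCKING above). Running
//
//   mlir-opt -xegpu-blocking input.mlir
//
// over a xegpu.load_nd of a 16x32xf16 tensor_desc whose layout carries
// inst_data = [8, 16] should leave four 8x16 loads whose results are
// recombined via vector.insert_strided_slice, per the patterns above.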