23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/Support/DebugLog.h"
29#define GEN_PASS_DEF_XEGPUBLOCKING
30#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
34#define DEBUG_TYPE "xegpu-blocking"
// XeGPU blocking pass (DEBUG_TYPE "xegpu-blocking").
// For each op it derives a tile shape from layout inst_data (getTileShape).
// It then decides whether the op must be unrolled to that shape
// (needsUnroll). runOnOperation applies the SCF type conversion and the
// unrolling patterns.
// The getTileShape overload declared below accepts only OpOperand or
// OpResult (enforced via enable_if).
48class XeGPUBlockingPass final
49 :
public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
51 void runOnOperation()
override;
58 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
59 std::is_same_v<T, OpResult>>>
60 std::optional<SmallVector<int64_t>>
// Returns the tile shape for a single OpOperand or OpResult.
//
// - If the value's layout is subgroup-level and has non-empty effective
//   inst_data, that inst_data is the tile. The return statement is not
//   visible in this extract.
// - Otherwise, for a ShapedType value, the full shape is returned, so the
//   value is not split.
// - If neither applies, a debug message is logged. Presumably the function
//   then returns std::nullopt (TODO confirm).
//
// NOTE(review): the layout is presumably obtained via
// xegpu::getDistributeLayoutAttr(value); that line is elided here.
73template <
typename T,
typename>
74std::optional<SmallVector<int64_t>>
75XeGPUBlockingPass::getTileShape(
const T &operandOrResult)
const {
  // Resolve the Value: the operand's current value, or the result itself.
77 if constexpr (std::is_same_v<T, OpOperand>) {
78 value = operandOrResult.get();
80 value = (Value)operandOrResult;
83 xegpu::DistributeLayoutAttr layout =
85 if (layout && layout.isForSubgroup()) {
86 if (!layout.getEffectiveInstDataAsInt().empty()) {
87 SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
  // Fallback when no inst_data tile was found: use the value's whole shape.
90 if (
auto type = dyn_cast<ShapedType>(value.
getType()))
91 return llvm::to_vector(type.getShape());
93 LDBG() <<
"failed to getTileShape for: " << value;
// Returns the tile shape that `op` should be unrolled to, dispatching on the
// op kind. Most per-branch return statements are not visible in this extract.
// From the visible code, the branches are:
//  - XeGPU memory ops (CreateNdDesc, LoadMatrix, PrefetchNd, LoadNd, Prefetch,
//    StoreMatrix, StoreNd, LoadGather, StoreScatter): presumably delegate to
//    the tile shape of one of the op's operands or results (TODO confirm).
//  - ConvertLayoutOp: returns either the input layout's or the target
//    layout's effective inst_data. The selection condition is not visible.
//  - DpasOp / DpasMxOp: validate the A, B and C (and scale) tiles against
//    each other. They build a [batch..., M, K, N] shape; DpasMxOp appends a
//    K scale factor.
//  - vector::MultiDimReductionOp and a list of shape/mask/cast vector ops are
//    handled at the end (bodies not visible).
// Presumably returns std::nullopt when no tile shape applies or validation
// fails (TODO confirm).
97std::optional<SmallVector<int64_t>>
98XeGPUBlockingPass::getTileShape(Operation *op)
const {
99 if (isa<xegpu::CreateNdDescOp, xegpu::LoadMatrixOp>(op))
101 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
102 xegpu::StoreMatrixOp>(op))
104 if (isa<xegpu::StoreNdOp>(op))
107 if (isa<xegpu::LoadGatherOp>(op))
110 if (
auto convertLayoutOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
112 convertLayoutOp.getInputLayout().getEffectiveInstDataAsInt();
113 auto targetInstData =
114 convertLayoutOp.getTargetLayout().getEffectiveInstDataAsInt();
117 return inputInstData;
119 return targetInstData;
122 if (isa<xegpu::StoreScatterOp>(op))
  // Fetches the A and B tiles (presumably from operands 0 and 1) and checks
  // that they are compatible for a matrix product:
  //  - both have rank >= 2;
  //  - they have the same batch rank and identical batch dims;
  //  - A's last dim equals B's dim at bBatchRank.
  // Writing A as [batch..., M, K] and B as [batch..., K, N], it returns the
  // pair {A, B} on success.
126 auto validateABTiles = [&](Operation *op)
127 -> std::optional<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> {
128 std::optional<SmallVector<int64_t>> aTile =
130 std::optional<SmallVector<int64_t>> bTile =
133 if (!aTile || aTile->size() < 2 || !bTile || bTile->size() < 2)
137 int64_t aBatchRank = aTile->size() - 2;
138 int64_t bBatchRank = bTile->size() - 2;
139 if (aBatchRank != bBatchRank)
143 for (int64_t i = 0; i < aBatchRank; ++i) {
144 if ((*aTile)[i] != (*bTile)[i])
150 if ((*aTile).back() != (*bTile)[bBatchRank])
153 return std::make_pair(*aTile, *bTile);
  // Checks that the tile of operand `cOperandIdx` (the C/accumulator) equals
  // [batch..., M, N]: A's batch dims, A's dim at aBatchRank, and B's last dim.
157 auto validateCTile = [&](Operation *op,
unsigned cOperandIdx,
158 const SmallVector<int64_t> &aTile,
159 const SmallVector<int64_t> &bTile) ->
bool {
163 std::optional<SmallVector<int64_t>> cTile =
168 int64_t aBatchRank = aTile.size() - 2;
169 SmallVector<int64_t> expectedCTile(aTile.begin(),
170 aTile.begin() + aBatchRank);
171 expectedCTile.push_back(aTile[aBatchRank]);
172 expectedCTile.push_back(bTile.back());
173 if (!llvm::equal(*cTile, expectedCTile))
  // Checks that the A-scale tile at `scaleAOperandIdx` has rank >= 2 and that
  // its second-to-last dim equals M (A's dim at aBatchRank). Returns the
  // scale tile's last (K-side) dim.
179 auto validateScaleATile =
180 [&](Operation *op,
unsigned scaleAOperandIdx,
181 const SmallVector<int64_t> &aTile) -> std::optional<int64_t> {
182 std::optional<SmallVector<int64_t>> aScaleTile =
185 if (!aScaleTile || aScaleTile->size() < 2)
190 int64_t scaleRank = aScaleTile->size();
191 int64_t aBatchRank = aTile.size() - 2;
192 if ((*aScaleTile)[scaleRank - 2] != aTile[aBatchRank])
196 return aScaleTile->back();
  // Checks that the B-scale tile at `scaleBOperandIdx` has rank >= 2 and that
  // its last dim equals N (B's last dim). Returns the scale tile's
  // second-to-last (K-side) dim.
200 auto validateScaleBTile =
201 [&](Operation *op,
unsigned scaleBOperandIdx,
202 const SmallVector<int64_t> &bTile) -> std::optional<int64_t> {
203 std::optional<SmallVector<int64_t>> bScaleTile =
206 if (!bScaleTile || bScaleTile->size() < 2)
211 if (bScaleTile->back() != bTile.back())
215 int64_t scaleRank = bScaleTile->size();
216 return (*bScaleTile)[scaleRank - 2];
  // DpasOp: A and B come first, and C is operand 2. The resulting tile is
  // [batch..., M, K, N]: A's batch dims, A's last two dims, then B's last dim.
219 if (isa<xegpu::DpasOp>(op)) {
220 auto abTiles = validateABTiles(op);
224 auto [aTile, bTile] = *abTiles;
227 if (!validateCTile(op, 2, aTile, bTile))
231 int64_t aBatchRank = aTile.size() - 2;
232 SmallVector<int64_t> tileShape(aTile.begin(), aTile.begin() + aBatchRank);
233 tileShape.push_back(aTile[aBatchRank]);
234 tileShape.push_back(aTile[aBatchRank + 1]);
235 tileShape.push_back(bTile.back());
  // DpasMxOp: operands are A, B, then the optional acc, scaleA and scaleB, in
  // that order. Each present optional operand shifts the index of the ones
  // after it by one.
239 if (
auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
240 auto abTiles = validateABTiles(op);
244 auto [aTile, bTile] = *abTiles;
247 if (dpasMxOp.getAcc()) {
248 unsigned accOperandIdx = 2;
249 if (!validateCTile(op, accOperandIdx, aTile, bTile))
  // The K scale factor defaults to 1 and is taken from whichever scale tiles
  // are present. When both are present they must agree; the mismatch handling
  // is not visible here.
254 int64_t kScaleFactor = 1;
255 std::optional<int64_t> scaleAFactor;
256 std::optional<int64_t> scaleBFactor;
258 if (dpasMxOp.getScaleA()) {
259 unsigned scaleAOperandIdx = 2 + (dpasMxOp.getAcc() ? 1 : 0);
260 scaleAFactor = validateScaleATile(op, scaleAOperandIdx, aTile);
265 if (dpasMxOp.getScaleB()) {
266 unsigned scaleBOperandIdx =
267 2 + (dpasMxOp.getAcc() ? 1 : 0) + (dpasMxOp.getScaleA() ? 1 : 0);
268 scaleBFactor = validateScaleBTile(op, scaleBOperandIdx, bTile);
274 if (scaleAFactor && scaleBFactor) {
275 if (*scaleAFactor != *scaleBFactor)
277 kScaleFactor = *scaleAFactor;
278 }
else if (scaleAFactor) {
279 kScaleFactor = *scaleAFactor;
280 }
else if (scaleBFactor) {
281 kScaleFactor = *scaleBFactor;
  // The resulting tile is [batch..., M, K, N, kScaleFactor].
285 int64_t aBatchRank = aTile.size() - 2;
286 SmallVector<int64_t> tileShape(aTile.begin(), aTile.begin() + aBatchRank);
287 tileShape.push_back(aTile[aBatchRank]);
288 tileShape.push_back(aTile[aBatchRank + 1]);
289 tileShape.push_back(bTile.back());
290 tileShape.push_back(kScaleFactor);
  // Vector ops: MultiDimReduction, then the listed shape, mask, bitcast and
  // interleave ops. Their return expressions are not visible in this extract.
297 if (isa<vector::MultiDimReductionOp>(op))
300 if (isa<vector::TransposeOp, vector::BroadcastOp, vector::StepOp,
301 vector::ShapeCastOp, vector::ConstantMaskOp, vector::CreateMaskOp,
302 vector::BitCastOp, vector::InterleaveOp, vector::DeinterleaveOp>(op))
// Decides whether `op` should be rewritten by the unrolling patterns.
//
// Ops with any operand or result carrying a workgroup-level layout are
// skipped: a debug message is logged, and the function presumably returns
// false (the return is elided here).
//
// Otherwise the op needs unrolling when either of these holds:
//  - some operand or result has a tile shape (getTileShape) for which
//    isUnrollable holds;
//  - it is a ConvertLayoutOp whose target layout carries inst_data.
308bool XeGPUBlockingPass::needsUnroll(Operation *op)
const {
310 bool hasWgLayoutOperands =
312 xegpu::DistributeLayoutAttr layout =
313 xegpu::getDistributeLayoutAttr(opr);
314 return layout && layout.isForWorkgroup();
316 bool hasWgLayoutResults =
318 xegpu::DistributeLayoutAttr layout =
319 xegpu::getDistributeLayoutAttr(result);
320 return layout && layout.isForWorkgroup();
322 if (hasWgLayoutOperands || hasWgLayoutResults) {
323 LDBG() <<
"skip unrolling for op with workgroup level layout: " << *op;
  // A TensorDesc value is unrollable iff its layout has inst_data. Any other
  // shaped value is unrollable iff its shape differs from the tile shape.
  // (`valTy` is presumably value.getType(); its definition is elided.)
327 auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
329 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
330 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
331 return layout && !layout.getEffectiveInstDataAsInt().empty();
333 auto shapedType = dyn_cast<ShapedType>(valTy);
334 return shapedType && !llvm::equal(tileShape, shapedType.getShape());
337 bool hasUnrollableOperands =
339 std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
340 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
342 bool hasUnrollableResults =
344 std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
345 return tileShape.has_value() && isUnrollable(result, *tileShape);
  // A ConvertLayoutOp whose target layout has inst_data always needs
  // unrolling, regardless of its operands and results.
  // NOTE(review): `targettLayout` is a misspelling of `targetLayout`. It is a
  // local name only, so renaming it is safe.
348 bool isConvertLayoutWithInstData =
false;
349 if (
auto convertLayoutOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
350 auto targettLayout = convertLayoutOp.getTargetLayout();
351 if (targettLayout && !targettLayout.getEffectiveInstDataAsInt().empty()) {
352 isConvertLayoutWithInstData =
true;
355 return hasUnrollableOperands || hasUnrollableResults ||
356 isConvertLayoutWithInstData;
359void XeGPUBlockingPass::runOnOperation() {
361 Operation *op = getOperation();
368 auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
369 xegpu::DistributeLayoutAttr layout) {
371 SmallVector<int64_t> tileShape(shape);
372 if (layout && !layout.getEffectiveInstDataAsInt().empty()) {
373 tileShape = layout.getEffectiveInstDataAsInt();
376 assert(count >= 1 &&
"count must be at least 1");
377 return std::make_pair(tileShape, count);
382 llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
384 [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
387 TypeConverter converter;
388 converter.addConversion([](Type type) -> Type {
return type; });
391 converter.addConversion(
392 [&](xegpu::TensorDescType type,
393 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
394 Type elemTy = type.getElementType();
395 ArrayRef<int64_t> shape = type.getShape();
397 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
398 if (layout && layout.isForWorkgroup())
402 SmallVector<int64_t> subShape;
403 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
406 layout = layout.dropInstData();
408 auto newTy = xegpu::TensorDescType::get(
409 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
410 result.append(count, newTy);
416 auto getSubShapeAndCount = [&](VectorType vecTy,
417 xegpu::DistributeLayoutAttr layout)
418 -> std::pair<SmallVector<int64_t>,
int> {
419 return getTileShapeAndCount(vecTy.getShape(), layout);
424 std::move(loopArgTypes));
432 op->
walk([](Operation *loopOp) {
433 if (!isa<scf::ForOp, scf::WhileOp, scf::ConditionOp>(loopOp))
435 SmallVector<StringRef> toRemove;
436 for (
const NamedAttribute &attr : loopOp->
getAttrs()) {
437 StringRef name = attr.getName().strref();
438 if (name.starts_with(
"layout_operand_") ||
439 name.starts_with(
"layout_result_"))
440 toRemove.push_back(name);
442 for (StringRef name : toRemove)
448 auto materializeCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
449 Location loc) -> Value {
450 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
453 converter.addSourceMaterialization(materializeCast);
454 converter.addTargetMaterialization(materializeCast);
457 converter.addTargetMaterialization(
458 [](mlir::OpBuilder &builder, mlir::TypeRange types,
459 mlir::ValueRange inputs, mlir::Location loc) -> SmallVector<Value> {
461 UnrealizedConversionCastOp::create(builder, loc, types, inputs);
462 return SmallVector<Value>(castOp.getResults());
465 ConversionTarget
target(*ctx);
466 target.addLegalOp<UnrealizedConversionCastOp>();
467 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
469 RewritePatternSet scfPatterns(ctx);
472 if (
failed(applyPartialConversion(op,
target, std::move(scfPatterns))))
473 return signalPassFailure();
481 [&](Operation *op) -> LogicalResult {
return success(needsUnroll(op)); });
485 options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
486 Type elemTy = type.getElementType();
488 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
490 Attribute encoding = tdescTy.getEncoding();
492 xegpu::TensorDescType newTy =
493 xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
494 tdescTy.getLayoutAttr().dropInstData());
496 ArrayRef<int64_t> shape = type.getShape();
499 return SmallVector<Type>(batchCount, newTy);
501 Type newTy = VectorType::get(tileShape, elemTy);
503 std::optional<SmallVector<int64_t>> ratio =
505 assert(ratio &&
"The shape of the type must be a multiple of tileShape.");
509 RewritePatternSet patterns(ctx);
510 vector::UnrollVectorOptions vectorOptions;
514 vector::populateVectorUnrollPatterns(patterns, vectorOptions);
523 op->
walk([](Operation *op) {
534 if (
auto layout = op->
getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
536 if (!isa<LoopLikeOpInterface>(op))
543 SmallVector<NamedAttribute> newAttrs =
555 RewritePatternSet emptyPatterns(ctx);
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumResults()
Return the number of results held by this operation.
Type getType() const
Return the type of this value.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
SmallVector< NamedAttribute > dropInstDataOnAttrs(ArrayRef< NamedAttribute > attrs)
Updates the NamedAttribute sequence by dropping inst-data information from any DistributeLayoutAttr f...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
DenseMap< Value, SmallVector< Type > > precomputeLoopBlockArgTypes(Operation *topLevelOp, SubShapeAndCountFn getSubShapeAndCount)
Pre-computes distributed VectorType mappings for every value carried through an SCF loop under topLev...
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
void addVectorTypeConversion(TypeConverter &converter, SubShapeAndCountFn getSubShapeAndCount, DenseMap< Value, SmallVector< Type > > loopArgTypes)
Adds a context-aware VectorType conversion to converter (1:1 shape-changing or 1:N,...
void cleanupUnrealizedConversionCasts(Operation *root, const llvm::SmallSetVector< UnrealizedConversionCastOp, 8 > &existingCasts)
Cleans up UnrealizedConversionCastOps inserted during SCF structural type conversion and/or XeGPU unr...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)