21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/DebugLog.h"
26#define GEN_PASS_DEF_XEGPUBLOCKING
27#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31#define DEBUG_TYPE "xegpu-blocking"
43resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
47 auto hasIdenticalVectorTypes = [](
ValueRange values) {
48 auto types = values.getTypes();
49 return llvm::all_of(types, [&](
Type type) {
50 return isa<VectorType>(type) && type == types.front();
56 if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
57 LDBG() <<
"skip unrealized conversion cast op not emulating pack/unpack.";
61 VectorType outputTy = dyn_cast<VectorType>(outputs[0].
getType());
63 if (inputs.size() > 1 && outputs.size() == 1) {
67 builder, castOp.getLoc(), inputs,
shape);
70 }
else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
74 builder, castOp.getLoc(), inputs[0], tileShape);
75 castOp->replaceAllUsesWith(results);
88class XeGPUBlockingPass final
91 void runOnOperation()
override;
98 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
99 std::is_same_v<T, OpResult>>>
100 std::optional<SmallVector<int64_t>>
113template <
typename T,
typename>
114std::optional<SmallVector<int64_t>>
115XeGPUBlockingPass::getTileShape(
const T &operandOrResult)
const {
117 if constexpr (std::is_same_v<T, OpOperand>) {
118 value = operandOrResult.get();
120 value = (Value)operandOrResult;
123 xegpu::DistributeLayoutAttr layout =
125 if (layout && layout.isForSubgroup()) {
126 if (!layout.getEffectiveInstDataAsInt().empty()) {
127 SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
130 if (
auto type = dyn_cast<ShapedType>(value.
getType()))
131 return llvm::to_vector(type.getShape());
133 LDBG() <<
"failed to getTileShape for: " << value;
137std::optional<SmallVector<int64_t>>
138XeGPUBlockingPass::getTileShape(Operation *op)
const {
139 if (isa<xegpu::CreateNdDescOp, xegpu::LoadMatrixOp>(op))
141 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
142 xegpu::StoreMatrixOp>(op))
144 if (isa<xegpu::StoreNdOp>(op))
147 if (isa<xegpu::LoadGatherOp>(op))
150 if (
auto convertLayoutOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
152 convertLayoutOp.getInputLayout().getEffectiveInstDataAsInt();
153 auto targetInstData =
154 convertLayoutOp.getTargetLayout().getEffectiveInstDataAsInt();
157 return inputInstData;
159 return targetInstData;
162 if (isa<xegpu::StoreScatterOp>(op))
166 auto validateABTiles = [&](Operation *op)
167 -> std::optional<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> {
168 std::optional<SmallVector<int64_t>> aTile =
170 std::optional<SmallVector<int64_t>> bTile =
173 if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
177 if ((*aTile)[1] != (*bTile)[0])
180 return std::make_pair(*aTile, *bTile);
184 auto validateCTile = [&](Operation *op,
unsigned cOperandIdx,
185 const SmallVector<int64_t> &aTile,
186 const SmallVector<int64_t> &bTile) ->
bool {
190 std::optional<SmallVector<int64_t>> cTile =
192 int64_t expectedCTile[2] = {aTile[0], bTile[1]};
193 if (!cTile || !llvm::equal(*cTile, expectedCTile))
199 auto validateScaleATile =
200 [&](Operation *op,
unsigned scaleAOperandIdx,
201 const SmallVector<int64_t> &aTile) -> std::optional<int64_t> {
202 std::optional<SmallVector<int64_t>> aScaleTile =
205 if (!aScaleTile || aScaleTile->size() != 2)
210 if ((*aScaleTile)[0] != aTile[0])
214 return (*aScaleTile)[1];
218 auto validateScaleBTile =
219 [&](Operation *op,
unsigned scaleBOperandIdx,
220 const SmallVector<int64_t> &bTile) -> std::optional<int64_t> {
221 std::optional<SmallVector<int64_t>> bScaleTile =
224 if (!bScaleTile || bScaleTile->size() != 2)
229 if ((*bScaleTile)[1] != bTile[1])
233 return (*bScaleTile)[0];
236 if (isa<xegpu::DpasOp>(op)) {
237 auto abTiles = validateABTiles(op);
241 auto [aTile, bTile] = *abTiles;
244 if (!validateCTile(op, 2, aTile, bTile))
247 return SmallVector<int64_t>({aTile[0], aTile[1], bTile[1]});
250 if (
auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
251 auto abTiles = validateABTiles(op);
255 auto [aTile, bTile] = *abTiles;
258 if (dpasMxOp.getAcc()) {
259 unsigned accOperandIdx = 2;
260 if (!validateCTile(op, accOperandIdx, aTile, bTile))
265 int64_t kScaleFactor = 1;
266 std::optional<int64_t> scaleAFactor;
267 std::optional<int64_t> scaleBFactor;
269 if (dpasMxOp.getScaleA()) {
270 unsigned scaleAOperandIdx = 2 + (dpasMxOp.getAcc() ? 1 : 0);
271 scaleAFactor = validateScaleATile(op, scaleAOperandIdx, aTile);
276 if (dpasMxOp.getScaleB()) {
277 unsigned scaleBOperandIdx =
278 2 + (dpasMxOp.getAcc() ? 1 : 0) + (dpasMxOp.getScaleA() ? 1 : 0);
279 scaleBFactor = validateScaleBTile(op, scaleBOperandIdx, bTile);
285 if (scaleAFactor && scaleBFactor) {
286 if (*scaleAFactor != *scaleBFactor)
288 kScaleFactor = *scaleAFactor;
289 }
else if (scaleAFactor) {
290 kScaleFactor = *scaleAFactor;
291 }
else if (scaleBFactor) {
292 kScaleFactor = *scaleBFactor;
295 return SmallVector<int64_t>({aTile[0], aTile[1], bTile[1], kScaleFactor});
301 if (isa<vector::MultiDimReductionOp>(op))
304 if (isa<vector::TransposeOp, vector::BroadcastOp, vector::StepOp,
305 vector::ShapeCastOp, vector::ConstantMaskOp, vector::CreateMaskOp,
306 vector::BitCastOp, vector::InterleaveOp, vector::DeinterleaveOp>(op))
312bool XeGPUBlockingPass::needsUnroll(Operation *op)
const {
314 bool hasWgLayoutOperands =
316 xegpu::DistributeLayoutAttr layout =
317 xegpu::getDistributeLayoutAttr(opr);
318 return layout && layout.isForWorkgroup();
320 bool hasWgLayoutResults =
322 xegpu::DistributeLayoutAttr layout =
323 xegpu::getDistributeLayoutAttr(result);
324 return layout && layout.isForWorkgroup();
326 if (hasWgLayoutOperands || hasWgLayoutResults) {
327 LDBG() <<
"skip unrolling for op with workgroup level layout: " << *op;
331 auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
333 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
334 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
335 return layout && !layout.getEffectiveInstDataAsInt().empty();
337 auto shapedType = dyn_cast<ShapedType>(valTy);
338 return shapedType && !llvm::equal(tileShape, shapedType.getShape());
341 bool hasUnrollableOperands =
343 std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
344 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
346 bool hasUnrollableResults =
348 std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
349 return tileShape.has_value() && isUnrollable(result, *tileShape);
352 bool isConvertLayoutWithInstData =
false;
353 if (
auto convertLayoutOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
354 auto targettLayout = convertLayoutOp.getTargetLayout();
355 if (targettLayout && !targettLayout.getEffectiveInstDataAsInt().empty()) {
356 isConvertLayoutWithInstData =
true;
359 return hasUnrollableOperands || hasUnrollableResults ||
360 isConvertLayoutWithInstData;
363void XeGPUBlockingPass::runOnOperation() {
365 Operation *op = getOperation();
372 auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
373 xegpu::DistributeLayoutAttr layout) {
375 SmallVector<int64_t> tileShape(shape);
376 if (layout && !layout.getEffectiveInstDataAsInt().empty()) {
377 tileShape = layout.getEffectiveInstDataAsInt();
380 return std::make_pair(tileShape, count);
384 TypeConverter converter;
385 converter.addConversion([](Type type) -> Type {
return type; });
386 converter.addConversion(
387 [&](RankedTensorType type,
388 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
389 Type elemTy = type.getElementType();
390 ArrayRef<int64_t> shape = type.getShape();
393 llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
394 if (layout && layout.isForWorkgroup())
398 SmallVector<int64_t> subShape;
399 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
400 auto newTy = VectorType::get(subShape, elemTy);
401 result.append(count, newTy);
404 converter.addConversion(
405 [&](xegpu::TensorDescType type,
406 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
407 Type elemTy = type.getElementType();
408 ArrayRef<int64_t> shape = type.getShape();
410 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
411 if (layout && layout.isForWorkgroup())
415 SmallVector<int64_t> subShape;
416 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
419 layout = layout.dropInstData();
421 auto newTy = xegpu::TensorDescType::get(
422 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
423 result.append(count, newTy);
431 [&](Operation *op) -> LogicalResult {
return success(needsUnroll(op)); });
435 options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
436 bool returnSingleType =
false) {
437 Type elemTy = type.getElementType();
440 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
442 Attribute encoding = tdescTy.getEncoding();
445 xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
446 tdescTy.getLayoutAttr().dropInstData());
448 newTy = VectorType::get(tileShape, elemTy);
451 if (returnSingleType)
452 return SmallVector<Type>{newTy};
453 std::optional<SmallVector<int64_t>> ratio =
455 assert(ratio &&
"The shape of the type must be a multiple of tileShape.");
459 RewritePatternSet patterns(ctx);
460 vector::UnrollVectorOptions vectorOptions;
464 vector::populateVectorUnrollPatterns(patterns, vectorOptions);
473 op->
walk([](Operation *op) {
484 if (
auto layout = op->
getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
486 if (!isa<LoopLikeOpInterface>(op))
492 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
493 resolveUnrealizedConversionCastOp(castOp);
498 RewritePatternSet emptyPatterns(ctx);
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
static llvm::ManagedStatic< PassManagerOptions > options
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)