MLIR 23.0.0git
XeGPUBlocking.cpp
Go to the documentation of this file.
1//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/DebugLog.h"
23
24namespace mlir {
25namespace xegpu {
26#define GEN_PASS_DEF_XEGPUBLOCKING
27#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
28} // namespace xegpu
29} // namespace mlir
30
31#define DEBUG_TYPE "xegpu-blocking"
32
33using namespace mlir;
34
35namespace {
37// reslove the unrealized conversion cast ops generated when doing SCF
38// Structural Type Conversion. It will have two formats, N:1 vector
39// cast and 1:N vector cast. vector::insert_strided_slice ops will be
40// used for the first case, and vector::extract_strided_slice ops will be
41// used for the second case.
42static void
43resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
44 ValueRange inputs = castOp.getInputs();
45 ValueRange outputs = castOp.getOutputs();
47 auto hasIdenticalVectorTypes = [](ValueRange values) {
48 auto types = values.getTypes();
49 return llvm::all_of(types, [&](Type type) {
50 return isa<VectorType>(type) && type == types.front();
51 });
52 };
53
54 // We only interest in the case where all inputs and outputs have the
55 // identical VectorTypes
56 if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
57 LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
58 return;
59 }
60
61 VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
62 OpBuilder builder(castOp);
63 if (inputs.size() > 1 && outputs.size() == 1) {
64 // the castOp is emulating an unpack op
65 ArrayRef<int64_t> shape = outputTy.getShape();
67 builder, castOp.getLoc(), inputs, shape);
68 castOp->replaceAllUsesWith(ValueRange(result));
69 castOp->erase();
70 } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
71 // the castOp is emulating a pack op
72 ArrayRef<int64_t> tileShape = outputTy.getShape();
74 builder, castOp.getLoc(), inputs[0], tileShape);
75 castOp->replaceAllUsesWith(results);
76 castOp->erase();
77 }
78}
79
//===------------------------------------------------------------------------===//
// The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
// to partition operations that process large shapes into multiple operations on
// smaller shapes, as specified by the inst_data in the layout attribute. This
// enables each resulting operation to be efficiently mapped to a hardware
// instruction.
//===------------------------------------------------------------------------===//

class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  // Pass entry point; see the banner comment above for the overall strategy.
  void runOnOperation() override;

private:
  // Get the tile shape for a given OpOperand or OpResult by examining the
  // corresponding layout attribute. If layout is not present or is not a
  // subgroup level layout, it returns std::nullopt.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape for a given operation. Which operand/result supplies
  // the shape is op-specific (e.g. result 0 for loads, operand 1 for StoreNd).
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all operands
  // and results have tile shapes identical to their original types. Otherwise,
  // return true.
  bool needsUnroll(Operation *op) const;
};
111} // namespace
112
113template <typename T, typename>
114std::optional<SmallVector<int64_t>>
115XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
116 Value value;
117 if constexpr (std::is_same_v<T, OpOperand>) {
118 value = operandOrResult.get();
119 } else {
120 value = (Value)operandOrResult;
121 }
122
123 xegpu::DistributeLayoutAttr layout =
124 xegpu::getDistributeLayoutAttr(operandOrResult);
125 if (layout && layout.isForSubgroup()) {
126 if (!layout.getEffectiveInstDataAsInt().empty()) {
127 SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
128 return instData;
129 }
130 if (auto type = dyn_cast<ShapedType>(value.getType()))
131 return llvm::to_vector(type.getShape());
132 }
133 LDBG() << "failed to getTileShape for: " << value;
134 return std::nullopt;
135}
136
137std::optional<SmallVector<int64_t>>
138XeGPUBlockingPass::getTileShape(Operation *op) const {
139 if (isa<xegpu::CreateNdDescOp, xegpu::LoadMatrixOp>(op))
140 return getTileShape(op->getOpResult(0));
141 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
142 xegpu::StoreMatrixOp>(op))
143 return getTileShape(op->getOpOperand(0));
144 if (isa<xegpu::StoreNdOp>(op))
145 return getTileShape(op->getOpOperand(1));
146
147 if (isa<xegpu::LoadGatherOp>(op))
148 return getTileShape(op->getOpResult(0));
149
150 if (auto convertLayoutOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
151 auto inputInstData =
152 convertLayoutOp.getInputLayout().getEffectiveInstDataAsInt();
153 auto targetInstData =
154 convertLayoutOp.getTargetLayout().getEffectiveInstDataAsInt();
155 // return the one with larger size
156 if (computeProduct(inputInstData) >= computeProduct(targetInstData))
157 return inputInstData;
158 else
159 return targetInstData;
160 }
161
162 if (isa<xegpu::StoreScatterOp>(op))
163 return getTileShape(op->getOpOperand(0));
164
165 // Helper lambda to validate and get A/B tiles
166 auto validateABTiles = [&](Operation *op)
167 -> std::optional<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> {
168 std::optional<SmallVector<int64_t>> aTile =
170 std::optional<SmallVector<int64_t>> bTile =
172
173 if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
174 return std::nullopt;
175
176 // semantic check for A and B
177 if ((*aTile)[1] != (*bTile)[0])
178 return std::nullopt;
179
180 return std::make_pair(*aTile, *bTile);
181 };
182
183 // Helper lambda to validate C tile
184 auto validateCTile = [&](Operation *op, unsigned cOperandIdx,
185 const SmallVector<int64_t> &aTile,
186 const SmallVector<int64_t> &bTile) -> bool {
187 if (op->getNumOperands() <= cOperandIdx)
188 return true;
189
190 std::optional<SmallVector<int64_t>> cTile =
191 getTileShape(op->getOpOperand(cOperandIdx));
192 int64_t expectedCTile[2] = {aTile[0], bTile[1]};
193 if (!cTile || !llvm::equal(*cTile, expectedCTile))
194 return false;
195 return true;
196 };
197
198 // Helper lambda to validate scale A tile for DpasMxOp
199 auto validateScaleATile =
200 [&](Operation *op, unsigned scaleAOperandIdx,
201 const SmallVector<int64_t> &aTile) -> std::optional<int64_t> {
202 std::optional<SmallVector<int64_t>> aScaleTile =
203 getTileShape(op->getOpOperand(scaleAOperandIdx));
204
205 if (!aScaleTile || aScaleTile->size() != 2)
206 return std::nullopt;
207
208 // Validate scale_a tile: [M_tile, K_scale]
209 // M dimension must match A's M dimension
210 if ((*aScaleTile)[0] != aTile[0])
211 return std::nullopt;
212
213 // Return the K scale factor
214 return (*aScaleTile)[1];
215 };
216
217 // Helper lambda to validate scale B tile for DpasMxOp
218 auto validateScaleBTile =
219 [&](Operation *op, unsigned scaleBOperandIdx,
220 const SmallVector<int64_t> &bTile) -> std::optional<int64_t> {
221 std::optional<SmallVector<int64_t>> bScaleTile =
222 getTileShape(op->getOpOperand(scaleBOperandIdx));
223
224 if (!bScaleTile || bScaleTile->size() != 2)
225 return std::nullopt;
226
227 // Validate scale_b tile: [K_scale, N_tile]
228 // N dimension must match B's N dimension
229 if ((*bScaleTile)[1] != bTile[1])
230 return std::nullopt;
231
232 // Return the K scale factor
233 return (*bScaleTile)[0];
234 };
235
236 if (isa<xegpu::DpasOp>(op)) {
237 auto abTiles = validateABTiles(op);
238 if (!abTiles)
239 return std::nullopt;
240
241 auto [aTile, bTile] = *abTiles;
242
243 // semantic check for C
244 if (!validateCTile(op, 2, aTile, bTile))
245 return std::nullopt;
246
247 return SmallVector<int64_t>({aTile[0], aTile[1], bTile[1]});
248 }
249
250 if (auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
251 auto abTiles = validateABTiles(op);
252 if (!abTiles)
253 return std::nullopt;
254
255 auto [aTile, bTile] = *abTiles;
256
257 // Validate C tile if present using op-specific accessor
258 if (dpasMxOp.getAcc()) {
259 unsigned accOperandIdx = 2; // acc is the 3rd operand
260 if (!validateCTile(op, accOperandIdx, aTile, bTile))
261 return std::nullopt;
262 }
263
264 // Validate scale tiles if present using op-specific accessors
265 int64_t kScaleFactor = 1;
266 std::optional<int64_t> scaleAFactor;
267 std::optional<int64_t> scaleBFactor;
268
269 if (dpasMxOp.getScaleA()) {
270 unsigned scaleAOperandIdx = 2 + (dpasMxOp.getAcc() ? 1 : 0);
271 scaleAFactor = validateScaleATile(op, scaleAOperandIdx, aTile);
272 if (!scaleAFactor)
273 return std::nullopt;
274 }
275
276 if (dpasMxOp.getScaleB()) {
277 unsigned scaleBOperandIdx =
278 2 + (dpasMxOp.getAcc() ? 1 : 0) + (dpasMxOp.getScaleA() ? 1 : 0);
279 scaleBFactor = validateScaleBTile(op, scaleBOperandIdx, bTile);
280 if (!scaleBFactor)
281 return std::nullopt;
282 }
283
284 // If both scales are present, their K dimensions must match
285 if (scaleAFactor && scaleBFactor) {
286 if (*scaleAFactor != *scaleBFactor)
287 return std::nullopt;
288 kScaleFactor = *scaleAFactor;
289 } else if (scaleAFactor) {
290 kScaleFactor = *scaleAFactor;
291 } else if (scaleBFactor) {
292 kScaleFactor = *scaleBFactor;
293 }
294
295 return SmallVector<int64_t>({aTile[0], aTile[1], bTile[1], kScaleFactor});
296 }
297
299 return getTileShape(op->getOpResult(0));
300
301 if (isa<vector::MultiDimReductionOp>(op))
302 return getTileShape(op->getOpOperand(0));
303
304 if (isa<vector::TransposeOp, vector::BroadcastOp, vector::StepOp,
305 vector::ShapeCastOp, vector::ConstantMaskOp, vector::CreateMaskOp,
306 vector::BitCastOp, vector::InterleaveOp, vector::DeinterleaveOp>(op))
307 return getTileShape(op->getOpResult(0));
308
309 return std::nullopt;
310}
311
312bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
313 // skip the op if any of its operands or results has workgroup level layouts
314 bool hasWgLayoutOperands =
315 llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
316 xegpu::DistributeLayoutAttr layout =
317 xegpu::getDistributeLayoutAttr(opr);
318 return layout && layout.isForWorkgroup();
319 });
320 bool hasWgLayoutResults =
321 llvm::any_of(op->getOpResults(), [](OpResult result) {
322 xegpu::DistributeLayoutAttr layout =
323 xegpu::getDistributeLayoutAttr(result);
324 return layout && layout.isForWorkgroup();
325 });
326 if (hasWgLayoutOperands || hasWgLayoutResults) {
327 LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
328 return false;
329 }
330
331 auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
332 Type valTy = value.getType();
333 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
334 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
335 return layout && !layout.getEffectiveInstDataAsInt().empty();
336 }
337 auto shapedType = dyn_cast<ShapedType>(valTy);
338 return shapedType && !llvm::equal(tileShape, shapedType.getShape());
339 };
340
341 bool hasUnrollableOperands =
342 llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
343 std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
344 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
345 });
346 bool hasUnrollableResults =
347 llvm::any_of(op->getOpResults(), [&](OpResult result) {
348 std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
349 return tileShape.has_value() && isUnrollable(result, *tileShape);
350 });
351 // ConvertLayoutOp must be processed to drop the inst_data in the layout
352 bool isConvertLayoutWithInstData = false;
353 if (auto convertLayoutOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
354 auto targettLayout = convertLayoutOp.getTargetLayout();
355 if (targettLayout && !targettLayout.getEffectiveInstDataAsInt().empty()) {
356 isConvertLayoutWithInstData = true;
357 }
358 }
359 return hasUnrollableOperands || hasUnrollableResults ||
360 isConvertLayoutWithInstData;
361}
362
363void XeGPUBlockingPass::runOnOperation() {
364 MLIRContext *ctx = &getContext();
365 Operation *op = getOperation();
366
368 signalPassFailure();
369 return;
370 }
371
372 auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
373 xegpu::DistributeLayoutAttr layout) {
374 int count = 1;
375 SmallVector<int64_t> tileShape(shape);
376 if (layout && !layout.getEffectiveInstDataAsInt().empty()) {
377 tileShape = layout.getEffectiveInstDataAsInt();
378 count = computeProduct(shape) / computeProduct(tileShape);
379 }
380 return std::make_pair(tileShape, count);
381 };
382
383 // Perform type conversion for SCF control folow ops
384 TypeConverter converter;
385 converter.addConversion([](Type type) -> Type { return type; });
386 converter.addConversion(
387 [&](RankedTensorType type,
388 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
389 Type elemTy = type.getElementType();
390 ArrayRef<int64_t> shape = type.getShape();
391
392 auto layout =
393 llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
394 if (layout && layout.isForWorkgroup())
395 return failure();
396
397 int count;
398 SmallVector<int64_t> subShape;
399 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
400 auto newTy = VectorType::get(subShape, elemTy);
401 result.append(count, newTy);
402 return success();
403 });
404 converter.addConversion(
405 [&](xegpu::TensorDescType type,
406 SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
407 Type elemTy = type.getElementType();
408 ArrayRef<int64_t> shape = type.getShape();
409
410 xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
411 if (layout && layout.isForWorkgroup())
412 return failure();
413
414 int count;
415 SmallVector<int64_t> subShape;
416 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
417
418 if (layout)
419 layout = layout.dropInstData();
420
421 auto newTy = xegpu::TensorDescType::get(
422 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
423 result.append(count, newTy);
424 return success();
425 });
426
428
429 xegpu::UnrollOptions options;
430 options.setFilterConstraint(
431 [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
432
433 options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
434
435 options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
436 bool returnSingleType = false) {
437 Type elemTy = type.getElementType();
438 Type newTy;
439
440 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
441
442 Attribute encoding = tdescTy.getEncoding();
443
444 newTy =
445 xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
446 tdescTy.getLayoutAttr().dropInstData());
447 } else {
448 newTy = VectorType::get(tileShape, elemTy);
449 }
450
451 if (returnSingleType)
452 return SmallVector<Type>{newTy};
453 std::optional<SmallVector<int64_t>> ratio =
454 computeShapeRatio(type.getShape(), tileShape);
455 assert(ratio && "The shape of the type must be a multiple of tileShape.");
456 return SmallVector<Type>(computeProduct(*ratio), newTy);
457 });
458
459 RewritePatternSet patterns(ctx);
460 vector::UnrollVectorOptions vectorOptions;
461 vectorOptions.setNativeShapeFn(options.nativeShape);
462
464 vector::populateVectorUnrollPatterns(patterns, vectorOptions);
465
466 // Note: The pattern driver does op folding as well and clean up.
467 // But intermediate insert/extract strided slice ops with
468 // unrealized conversion cast ops in the middle does not get
469 // cleaned up in this step. One more round of folding is needed
470 // after the walk to resolve those unrealized conversion cast ops.
471 (void)applyPatternsGreedily(op, std::move(patterns));
472
473 op->walk([](Operation *op) {
474 // Remove the layout attributes cached per operands.
475 for (OpOperand &opr : op->getOpOperands()) {
476 std::string name = xegpu::getTemporaryLayoutName(opr);
477 if (op->hasAttrOfType<xegpu::DistributeLayoutAttr>(name))
478 op->removeAttr(name);
479 }
480
481 // Update the layout attributes per result.
482 for (OpResult result : op->getOpResults()) {
483 std::string name = xegpu::getTemporaryLayoutName(result);
484 if (auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
485 op->removeAttr(name);
486 if (!isa<LoopLikeOpInterface>(op))
487 xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
488 }
489 }
490
491 // Resolve unrealized conversion cast ops emulating pack/unpack
492 if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
493 resolveUnrealizedConversionCastOp(castOp);
494 });
495
496 // One more round of folding to clean up the intermediate
497 // insert/extract strided slice ops.
498 RewritePatternSet emptyPatterns(ctx);
499 (void)applyPatternsGreedily(op, std::move(emptyPatterns));
500}
return success()
b getContext())
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Definition MMAUtils.cpp:37
static llvm::ManagedStatic< PassManagerOptions > options
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getOpResult(unsigned idx)
Definition Operation.h:447
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:576
bool hasAttrOfType(NameT &&name)
Definition Operation.h:601
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
unsigned getNumOperands()
Definition Operation.h:372
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
result_range getOpResults()
Definition Operation.h:446
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition Operation.h:626
OpOperand & getOpOperand(unsigned idx)
Definition Operation.h:414
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult user should use setAnchorLayout...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)