#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-blocking"

using namespace mlir;
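// Resolves an unrealized_conversion_cast generated by the SCF structural
// type conversion below. Such casts come in two forms: an N:1 cast that
// reassembles N tile vectors into one large vector (via
// vector.insert_strided_slice), and a 1:N cast that splits one large vector
// into N tiles (via vector.extract_strided_slice).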
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };
  // Only casts where all inputs and all outputs share one vector type can
  // be emulating a pack/unpack.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // The castOp is emulating an unpack op: reassemble the tiles into one
    // vector of the output shape.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // The castOp is emulating a pack op: extract tiles of the output shape
    // from the single input vector.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}
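// Drops the inst_data field from the layouts of xegpu.convert_layout ops:
// once blocking has materialized the instruction-level tiling, only the
// remaining layout fields are still meaningful.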
struct ConvertLayoutOpPattern
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
        targetLayout.getEffectiveInstDataAsInt().empty())
      return rewriter.notifyMatchFailure(op, "not a target ConvertLayoutOp");

    inputLayout = inputLayout.dropInstData();
    targetLayout = targetLayout.dropInstData();
    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
        op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
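// Unrolls ops working on large (subgroup-level) vectors or tensor descs
// into ops working on the instruction-level tiles described by the
// inst_data field of their layout attributes.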
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the instruction-level tile shape for a given OpOperand or OpResult.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape for a given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Whether the operation should be unrolled by this pass.
  bool needsUnroll(Operation *op) const;
};
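// The tile shape of a value is its layout's inst_data when present;
// otherwise the value's own shape is used, meaning no unrolling is needed
// for it.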
template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::DistributeLayoutAttr layout =
      xegpu::getDistributeLayoutAttr(operandOrResult);
  if (layout && layout.isForSubgroup()) {
    if (!layout.getEffectiveInstDataAsInt().empty()) {
      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
      // Drop the leading unit dimensions of inst_data, except for ops that
      // need the original rank to be preserved.
      Operation *definingOp = value.getDefiningOp();
      bool skipLeadingUnitDimRemoval =
          definingOp &&
          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
      if (!skipLeadingUnitDimRemoval) {
        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
        instData.erase(instData.begin(), it);
      }
      return instData;
    }
    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG() << "failed to getTileShape for: " << value;
  return std::nullopt;
}
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
          xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
          xegpu::StoreMatrixOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
    if (loadGatherOp.getOffsets())
      return getTileShape(loadGatherOp->getOpResult(0));
    return getTileShape(loadGatherOp->getOpOperand(0));
  }

  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
    return getTileShape(storeScatterOp.getOffsets()
                            ? storeScatterOp->getOpOperand(0)
                            : storeScatterOp->getOpOperand(1));
  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile =
        getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile =
        getTileShape(op->getOpOperand(1));

    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // A[M, K] x B[K, N]: the reduction dims of A and B must match.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // If an accumulator is present, its tile must be [M, N].
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile =
          getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }

    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
    return getTileShape(op->getOpResult(0));

  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));

  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}
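// An op needs unrolling when its layouts are subgroup-level (workgroup-level
// layouts are left to workgroup distribution) and at least one operand or
// result is larger than its tile shape.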
bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  // Skip ops with workgroup-level layouts on any operand or result.
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(opr);
        return layout && layout.isForWorkgroup();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(result);
        return layout && layout.isForWorkgroup();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
    return false;
  }
  // A value is unrollable if its tensor desc layout still carries inst_data,
  // or if its shape differs from the computed tile shape.
  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && !layout.getEffectiveInstDataAsInt().empty();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}
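// Pass driver: cache layouts as temporary attributes, run a structural type
// conversion for SCF ops, greedily apply the unroll patterns, then restore
// layouts and resolve leftover unrealized casts.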
void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Preserve the layout of each operand/result on its owner op, so the
  // layout remains accessible even if the defining op is replaced.
  xegpu::recoverTemporaryLayoutsDeprecated(op);

  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };
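// For example, a 32x32 shape with inst_data = [8, 16] yields
// tileShape = [8, 16] and count = (32 * 32) / (8 * 16) = 8 tiles.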
  // 1:N type conversion for SCF structural ops: a type carrying a
  // subgroup-level layout is expanded into `count` tile-sized types.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isForWorkgroup())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        if (layout)
          layout = layout.dropInstData();
        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);

  RewritePatternSet patterns(ctx);
  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);

  xegpu::UnrollOptions options;
  options.setFilterConstraint(
      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
                                 bool returnSingleType = false) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
      Attribute encoding = tdescTy.getEncoding();

      // For scattered tensor descs, rewrite the chunk size to the innermost
      // inst_data dimension so it matches the blocked tile.
      if (tdescTy.isScattered()) {
        int64_t chunkSize = tdescTy.getChunkSizeAsInt();

        if (chunkSize > 1) {
          int64_t blockedChunkSize = chunkSize;
          auto instData = tdescTy.getLayoutAttr().getInstData();
          if (!instData.empty())
            blockedChunkSize = instData.asArrayRef().back();

          // Create a new encoding with the blocked chunk size.
          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
              ctx, tdescTy.getMemorySpace(), blockedChunkSize);
          encoding = newEncoding;
        }
      }

      newTy =
          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                     tdescTy.getLayoutAttr().dropInstData());
    } else {
      newTy = VectorType::get(tileShape, elemTy);
    }

    if (returnSingleType)
      return SmallVector<Type>{newTy};
    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });
  patterns.add<ConvertLayoutOpPattern>(ctx);
  xegpu::populateXeGPUUnrollPatterns(patterns, options);

  // Reuse the same native-shape callback for the plain vector ops.
  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn(options.nativeShape);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(op, std::move(patterns));
  // Post-processing: remove the temporary layout attributes and resolve the
  // unrealized conversion casts left by the structural type conversion.
  op->walk([](Operation *op) {
    // Remove the layout attributes cached per operand.
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getTemporaryLayoutName(opr);
      if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
        op->removeAttr(name);
    }

    // Update the layout attributes per result.
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getTemporaryLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
      }
    }

    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}
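// A minimal usage sketch, assuming the pass is registered under its
// generated "xegpu-blocking" pipeline name:
//   mlir-opt --xegpu-blocking input.mlir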