19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/DebugLog.h"
24 #define GEN_PASS_DEF_XEGPUBLOCKING
25 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "xegpu-blocking"
41 resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
45 auto hasIdenticalVectorTypes = [](
ValueRange values) {
47 return llvm::all_of(types, [&](
Type type) {
48 return isa<VectorType>(type) && type == types.front();
54 if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
55 LDBG() <<
"skip unrealized conversion cast op not emulating pack/unpack.";
59 VectorType outputTy = dyn_cast<VectorType>(outputs[0].
getType());
61 if (inputs.size() > 1 && outputs.size() == 1) {
65 builder, castOp.getLoc(), inputs, shape);
66 castOp->replaceAllUsesWith(
ValueRange(result));
68 }
else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
72 builder, castOp.getLoc(), inputs[0], tileShape);
73 castOp->replaceAllUsesWith(results);
82 struct ConvertLayoutOpPattern
85 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
87 xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
88 xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
89 if (input_layout.getInstDataAsInt().empty() ||
90 target_layout.getInstDataAsInt().empty())
93 input_layout = input_layout.dropInstData();
94 target_layout = target_layout.dropInstData();
95 auto newOp = rewriter.
createOrFold<xegpu::ConvertLayoutOp>(
96 op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
110 class XeGPUBlockingPass final
111 :
public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
113 void runOnOperation()
override;
119 template <
typename T,
120 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
121 std::is_same_v<T, OpResult>>>
122 std::optional<SmallVector<int64_t>>
135 template <
typename T,
typename>
136 std::optional<SmallVector<int64_t>>
139 if constexpr (std::is_same_v<T, OpOperand>)
140 value = operandOrResult.get();
142 value = (
Value)operandOrResult;
144 xegpu::DistributeLayoutAttr layout =
146 if (layout && layout.isForSubgroup()) {
147 if (!layout.getInstDataAsInt().empty())
148 return layout.getInstDataAsInt();
150 if (
auto type = dyn_cast<ShapedType>(value.
getType()))
151 return llvm::to_vector(type.getShape());
153 LDBG() <<
"failed to getTileShape for: " << value;
157 std::optional<SmallVector<int64_t>>
159 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
160 xegpu::UpdateOffsetOp>(op))
162 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
163 xegpu::LoadGatherOp>(op))
165 if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
168 if (isa<xegpu::DpasOp>(op)) {
169 std::optional<SmallVector<int64_t>> aTile =
171 std::optional<SmallVector<int64_t>> bTile =
174 if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
178 if ((*aTile)[1] != (*bTile)[0])
183 std::optional<SmallVector<int64_t>> cTile =
185 int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
186 if (!cTile || !llvm::equal(*cTile, expectedCTile))
196 if (isa<vector::MultiDimReductionOp>(op))
199 if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
205 bool XeGPUBlockingPass::needsUnroll(
Operation *op)
const {
207 bool hasWgLayoutOperands =
209 xegpu::DistributeLayoutAttr layout =
210 xegpu::getDistributeLayoutAttr(opr);
211 return layout && layout.isForWorkgroup();
213 bool hasWgLayoutResults =
215 xegpu::DistributeLayoutAttr layout =
216 xegpu::getDistributeLayoutAttr(result);
217 return layout && layout.isForWorkgroup();
219 if (hasWgLayoutOperands || hasWgLayoutResults) {
220 LDBG() <<
"skip unrolling for op with workgroup level layout: " << *op;
226 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
227 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
228 return layout && !layout.getInstDataAsInt().empty();
230 auto shapedType = dyn_cast<ShapedType>(valTy);
231 return shapedType && !llvm::equal(tileShape, shapedType.getShape());
234 bool hasUnrollableOperands =
236 std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
237 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
239 bool hasUnrollableResults =
241 std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
242 return tileShape.has_value() && isUnrollable(result, *tileShape);
244 return hasUnrollableOperands || hasUnrollableResults;
247 void XeGPUBlockingPass::runOnOperation() {
258 xegpu::LayoutAttr layout) {
261 if (layout && layout.getInstData()) {
263 tileShape = llvm::to_vector_of<int64_t>(instData.
asArrayRef());
266 return std::make_pair(tileShape, count);
273 [&](RankedTensorType type,
275 Type elemTy = type.getElementType();
279 llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
280 if (layout && layout.isForWorkgroup())
285 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
287 result.append(count, newTy);
291 [&](xegpu::TensorDescType type,
293 Type elemTy = type.getElementType();
296 xegpu::LayoutAttr layout = type.getLayoutAttr();
297 if (layout && layout.isForWorkgroup())
302 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
305 layout = layout.dropInstData();
308 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
309 result.append(count, newTy);
317 [&](
Operation *op) -> LogicalResult {
return success(needsUnroll(op)); });
322 Type elemTy = type.getElementType();
325 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
327 Attribute encoding = tdescTy.getEncoding();
330 if (tdescTy.isScattered()) {
331 int64_t chunkSize = tdescTy.getChunkSizeAsInt();
334 int64_t blockedChunkSize = chunkSize;
335 auto instData = tdescTy.getLayoutAttr().getInstData();
336 if (!instData.empty())
337 blockedChunkSize = instData.asArrayRef().back();
341 ctx, tdescTy.getMemorySpace(), blockedChunkSize);
343 encoding = newEncoding;
349 tdescTy.getLayoutAttr().dropInstData());
351 newTy = type.clone(tileShape, elemTy);
354 std::optional<SmallVector<int64_t>> ratio =
356 assert(ratio &&
"The shape of the type must be a multiple of tileShape.");
361 patterns.add<ConvertLayoutOpPattern>(ctx);
367 vector::populateVectorUnrollPatterns(
patterns, vectorOptions);
382 if (
auto layout = op->
getAttrOfType<xegpu::LayoutAttr>(name)) {
384 if (!isa<LoopLikeOpInterface>(op))
390 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
391 resolveUnrealizedConversionCastOp(castOp);
static MLIRContext * getContext(OpFoldResult val)
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
OpOperand & getOpOperand(unsigned idx)
bool hasAttrOfType(NameT &&name)
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictio...
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
void setDistributeLayoutAttrs(Operation *op, function_ref< DistributeLayoutAttr(Value)> getLayoutImpl)
Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
Options to control the XeGPU unrolling.