19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/Support/DebugLog.h"
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUBLOCKING
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

using namespace mlir;
#define DEBUG_TYPE "xegpu-blocking"
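
// Resolve an unrealized_conversion_cast that emulates a pack or unpack
// between one large vector and several identical smaller vectors: N:1 casts
// are rebuilt with vector.insert_strided_slice and 1:N casts with
// vector.extract_strided_slice, via the xegpu vector utilities used below.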
static void
resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
  ValueRange inputs = castOp.getInputs();
  ValueRange outputs = castOp.getOutputs();

  auto hasIdenticalVectorTypes = [](ValueRange values) {
    auto types = values.getTypes();
    return llvm::all_of(types, [&](Type type) {
      return isa<VectorType>(type) && type == types.front();
    });
  };

  // Only casts where all inputs and all outputs share a single vector type
  // are emulating pack/unpack; leave everything else alone.
  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
    LDBG() << "skip unrealized conversion cast op not emulating pack/unpack.";
    return;
  }

  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
  OpBuilder builder(castOp);
  if (inputs.size() > 1 && outputs.size() == 1) {
    // N:1 cast (unpack): recombine the small vectors into the output shape.
    ArrayRef<int64_t> shape = outputTy.getShape();
    Value result = xegpu::createVectorWithShapeFromValues(
        builder, castOp.getLoc(), inputs, shape);
    castOp->replaceAllUsesWith(ValueRange(result));
    castOp->erase();
  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
    // 1:N cast (pack): split the single input into tile-shaped vectors.
    ArrayRef<int64_t> tileShape = outputTy.getShape();
    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
        builder, castOp.getLoc(), inputs[0], tileShape);
    castOp->replaceAllUsesWith(results);
    castOp->erase();
  }
}
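
// Once blocking has been applied, the inst_data field of a layout is no
// longer needed. This pattern rewrites xegpu.convert_layout so that both its
// input and target layouts drop inst_data.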
struct ConvertLayoutOpPattern
    : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
    xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
    if (!input_layout.getInstData() || !target_layout.getInstData())
      return rewriter.notifyMatchFailure(op, "not a target ConvertLayoutOp");

    input_layout = input_layout.dropInstData();
    target_layout = target_layout.dropInstData();
    auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
        op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
    rewriter.replaceOp(op, newOp);
    return success();
  }
};
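
//===----------------------------------------------------------------------===//
// XeGPUBlockingPass rewrites XeGPU and vector operations that work on shapes
// larger than one hardware instruction into multiple operations on the
// instruction-sized tiles given by the inst_data field of the layout
// attribute. For example, an op on a 32x32 vector whose layout carries
// inst_data = [8, 16] is unrolled into eight ops on 8x16 vectors.
//===----------------------------------------------------------------------===//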
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for an OpOperand or OpResult from its layout:
  // the inst_data sizes if present, otherwise the full shape of the value.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the native tile shape used to unroll the given operation.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Return true if the op has an operand or result that must be unrolled.
  bool needsUnroll(Operation *op) const;
};
template <typename T, typename>
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
  Value value;
  if constexpr (std::is_same_v<T, OpOperand>)
    value = operandOrResult.get();
  else
    value = (Value)operandOrResult;

  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
  if (layout && layout.isSgLayout()) {
    if (auto inst_data = layout.getInstData())
      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());

    if (auto type = dyn_cast<ShapedType>(value.getType()))
      return llvm::to_vector(type.getShape());
  }
  LDBG() << "failed to getTileShape for: " << value;
  return std::nullopt;
}
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
          xegpu::UpdateOffsetOp>(op))
    return getTileShape(op->getOpResult(0));
  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
          xegpu::LoadGatherOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
    return getTileShape(op->getOpOperand(1));

  if (isa<xegpu::DpasOp>(op)) {
    std::optional<SmallVector<int64_t>> aTile = getTileShape(op->getOpOperand(0));
    std::optional<SmallVector<int64_t>> bTile = getTileShape(op->getOpOperand(1));
    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
      return std::nullopt;

    // The reduction dimension of the A tile must match that of the B tile.
    if ((*aTile)[1] != (*bTile)[0])
      return std::nullopt;

    // If an accumulator is present, its tile must be MxN.
    if (op->getNumOperands() == 3) {
      std::optional<SmallVector<int64_t>> cTile = getTileShape(op->getOpOperand(2));
      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
      if (!cTile || !llvm::equal(*cTile, expectedCTile))
        return std::nullopt;
    }
    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
  }

  if (isa<vector::MultiDimReductionOp>(op))
    return getTileShape(op->getOpOperand(0));
  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
    return getTileShape(op->getOpResult(0));

  return std::nullopt;
}
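
// An op needs unrolling only if none of its operands or results carry a
// workgroup-level layout and at least one of them has a tile shape that
// differs from the shape of its type (or is a tensor_desc with inst_data).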
bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
  bool hasWgLayoutOperands =
      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
        return layout && layout.isWgLayout();
      });
  bool hasWgLayoutResults =
      llvm::any_of(op->getOpResults(), [](OpResult result) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
        return layout && layout.isWgLayout();
      });
  if (hasWgLayoutOperands || hasWgLayoutResults) {
    LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
    return false;
  }

  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
    Type valTy = value.getType();
    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
      xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
      return layout && layout.getInstData();
    }
    auto shapedType = dyn_cast<ShapedType>(valTy);
    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
  };

  bool hasUnrollableOperands =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
      });
  bool hasUnrollableResults =
      llvm::any_of(op->getOpResults(), [&](OpResult result) {
        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
        return tileShape.has_value() && isUnrollable(result, *tileShape);
      });
  return hasUnrollableOperands || hasUnrollableResults;
}
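
// Pass driver: cache the layout of every operand/result on its owner, run an
// SCF structural type conversion so loop-carried values can be split into
// tile-sized pieces, greedily apply the XeGPU and vector unroll patterns, and
// finally clean up the cached layouts and the pack/unpack casts.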
void XeGPUBlockingPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  Operation *op = getOperation();

  // Attach the layout of each operand/result to its owner op so the layout
  // stays available after the producing op is rewritten.
  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });

  // Compute the tile shape and how many tiles a given shape decomposes into.
  auto getTileShapeAndCount = [](ArrayRef<int64_t> shape,
                                 xegpu::LayoutAttr layout) {
    int count = 1;
    SmallVector<int64_t> tileShape(shape);
    if (layout && layout.getInstData()) {
      DenseI32ArrayAttr instData = layout.getInstData();
      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
      count = computeProduct(shape) / computeProduct(tileShape);
    }
    return std::make_pair(tileShape, count);
  };
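  // For example, shape [32, 32] with inst_data = [8, 16] yields
  // tileShape = [8, 16] and count = (32 * 32) / (8 * 16) = 8.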

  // Type converter for the SCF structural type conversion: vectors are
  // carried as RankedTensorType during the conversion, and both tensors and
  // tensor descriptors expand to `count` copies of their tile type.
  TypeConverter converter;
  converter.addConversion([](Type type) -> Type { return type; });
  converter.addConversion(
      [&](RankedTensorType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        auto layout =
            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
        if (layout && layout.isWgLayout())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        auto newTy = VectorType::get(subShape, elemTy);
        result.append(count, newTy);
        return success();
      });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        xegpu::LayoutAttr layout = type.getLayoutAttr();
        if (layout && layout.isWgLayout())
          return failure();

        int count;
        SmallVector<int64_t> subShape;
        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
        if (layout)
          layout = layout.dropInstData();
        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });
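
  // Convert SCF structural ops (e.g. scf.for) with the converter above. The
  // 1:N type expansion introduces unrealized_conversion_cast ops that emulate
  // pack/unpack; they are resolved in the walk at the end of this pass.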
  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);

  xegpu::UnrollOptions options;
  options.setFilterConstraint(
      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
    Type elemTy = type.getElementType();
    Type newTy;

    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
      Attribute encoding = tdescTy.getEncoding();
      // A scattered descriptor may need a smaller chunk size after blocking:
      // use the innermost inst_data entry as the blocked chunk size.
      if (tdescTy.isScattered()) {
        int64_t chunkSize = tdescTy.getChunkSizeAsInt();
        if (chunkSize > 1) {
          int64_t blockedChunkSize = chunkSize;
          auto instData = tdescTy.getLayoutAttr().getInstData();
          if (!instData.empty())
            blockedChunkSize = instData.asArrayRef().back();
          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
              ctx, tdescTy.getMemorySpace(), blockedChunkSize);
          encoding = newEncoding;
        }
      }
      newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                         tdescTy.getLayoutAttr().dropInstData());
    } else {
      newTy = type.clone(tileShape, elemTy);
    }

    std::optional<SmallVector<int64_t>> ratio =
        computeShapeRatio(type.getShape(), tileShape);
    assert(ratio && "The shape of the type must be a multiple of tileShape.");
    return SmallVector<Type>(computeProduct(*ratio), newTy);
  });
  RewritePatternSet patterns(ctx);
  patterns.add<ConvertLayoutOpPattern>(ctx);

  vector::UnrollVectorOptions vectorOptions;
  vectorOptions.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  vector::populateVectorUnrollPatterns(patterns, vectorOptions);

  (void)applyPatternsGreedily(op, std::move(patterns));
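
  // Post-processing: drop the cached per-operand layouts, re-attach result
  // layouts without inst_data, and materialize the remaining pack/unpack
  // casts left behind by the SCF type conversion.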
  op->walk([](Operation *op) {
    for (OpOperand &opr : op->getOpOperands()) {
      std::string name = xegpu::getLayoutName(opr);
      op->removeAttr(name);
    }

    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<LoopLikeOpInterface>(op))
          xegpu::setLayoutAttr(result, layout.dropInstData());
      }
    }

    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
      resolveUnrealizedConversionCastOp(castOp);
  });
}