21#include "llvm/ADT/STLExtras.h"
22#include "llvm/Support/DebugLog.h"
26#define GEN_PASS_DEF_XEGPUBLOCKING
27#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
31#define DEBUG_TYPE "xegpu-blocking"
43resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
48 auto types = values.getTypes();
49 return llvm::all_of(types, [&](
Type type) {
50 return isa<VectorType>(type) && type == types.front();
56 if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
57 LDBG() <<
"skip unrealized conversion cast op not emulating pack/unpack.";
61 VectorType outputTy = dyn_cast<VectorType>(outputs[0].
getType());
63 if (inputs.size() > 1 && outputs.size() == 1) {
67 builder, castOp.getLoc(), inputs,
shape);
70 }
else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
74 builder, castOp.getLoc(), inputs[0], tileShape);
75 castOp->replaceAllUsesWith(results);
84struct ConvertLayoutOpPattern
87 LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
89 xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
90 xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
91 if (inputLayout.getEffectiveInstDataAsInt().empty() ||
92 targetLayout.getEffectiveInstDataAsInt().empty())
95 inputLayout = inputLayout.dropInstData();
96 targetLayout = targetLayout.dropInstData();
97 auto newOp = rewriter.
createOrFold<xegpu::ConvertLayoutOp>(
98 op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
112class XeGPUBlockingPass final
115 void runOnOperation()
override;
121 template <
typename T,
122 typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
123 std::is_same_v<T, OpResult>>>
124 std::optional<SmallVector<int64_t>>
137template <
typename T,
typename>
138std::optional<SmallVector<int64_t>>
139XeGPUBlockingPass::getTileShape(
const T &operandOrResult)
const {
141 if constexpr (std::is_same_v<T, OpOperand>) {
142 value = operandOrResult.get();
144 value = (Value)operandOrResult;
147 xegpu::DistributeLayoutAttr layout =
149 if (layout && layout.isForSubgroup()) {
150 if (!layout.getEffectiveInstDataAsInt().empty()) {
151 SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
154 if (
auto type = dyn_cast<ShapedType>(value.
getType()))
155 return llvm::to_vector(type.getShape());
157 LDBG() <<
"failed to getTileShape for: " << value;
161std::optional<SmallVector<int64_t>>
162XeGPUBlockingPass::getTileShape(Operation *op)
const {
163 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
164 xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
166 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
167 xegpu::StoreMatrixOp>(op))
169 if (isa<xegpu::StoreNdOp>(op))
173 if (
auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
174 if (loadGatherOp.getOffsets())
180 if (
auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
182 ? storeScatterOp->getOpOperand(0)
183 : storeScatterOp->getOpOperand(1));
185 if (isa<xegpu::DpasOp>(op)) {
186 std::optional<SmallVector<int64_t>> aTile =
188 std::optional<SmallVector<int64_t>> bTile =
191 if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
195 if ((*aTile)[1] != (*bTile)[0])
200 std::optional<SmallVector<int64_t>> cTile =
202 int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
203 if (!cTile || !llvm::equal(*cTile, expectedCTile))
207 return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
213 if (isa<vector::MultiDimReductionOp>(op))
216 if (isa<vector::TransposeOp, vector::BroadcastOp, vector::StepOp,
217 vector::ConstantMaskOp, vector::CreateMaskOp>(op))
223bool XeGPUBlockingPass::needsUnroll(Operation *op)
const {
225 bool hasWgLayoutOperands =
227 xegpu::DistributeLayoutAttr layout =
228 xegpu::getDistributeLayoutAttr(opr);
229 return layout && layout.isForWorkgroup();
231 bool hasWgLayoutResults =
233 xegpu::DistributeLayoutAttr layout =
234 xegpu::getDistributeLayoutAttr(result);
235 return layout && layout.isForWorkgroup();
237 if (hasWgLayoutOperands || hasWgLayoutResults) {
238 LDBG() <<
"skip unrolling for op with workgroup level layout: " << *op;
242 auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
244 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
245 xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
246 return layout && !layout.getEffectiveInstDataAsInt().empty();
248 auto shapedType = dyn_cast<ShapedType>(valTy);
249 return shapedType && !llvm::equal(tileShape, shapedType.getShape());
252 bool hasUnrollableOperands =
254 std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
255 return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
257 bool hasUnrollableResults =
259 std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
260 return tileShape.has_value() && isUnrollable(result, *tileShape);
262 return hasUnrollableOperands || hasUnrollableResults;
265void XeGPUBlockingPass::runOnOperation() {
267 Operation *op = getOperation();
274 auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
275 xegpu::LayoutAttr layout) {
277 SmallVector<int64_t> tileShape(shape);
278 if (layout && layout.getInstData()) {
280 tileShape = llvm::to_vector_of<int64_t>(instData.
asArrayRef());
283 return std::make_pair(tileShape, count);
287 TypeConverter converter;
288 converter.addConversion([](Type type) -> Type {
return type; });
289 converter.addConversion(
290 [&](RankedTensorType type,
291 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
292 Type elemTy = type.getElementType();
293 ArrayRef<int64_t> shape = type.getShape();
296 llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
297 if (layout && layout.isForWorkgroup())
301 SmallVector<int64_t> subShape;
302 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
303 auto newTy = VectorType::get(subShape, elemTy);
304 result.append(count, newTy);
307 converter.addConversion(
308 [&](xegpu::TensorDescType type,
309 SmallVectorImpl<Type> &
result) -> std::optional<LogicalResult> {
310 Type elemTy = type.getElementType();
311 ArrayRef<int64_t> shape = type.getShape();
313 xegpu::LayoutAttr layout = type.getLayoutAttr();
314 if (layout && layout.isForWorkgroup())
318 SmallVector<int64_t> subShape;
319 std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
322 layout = layout.dropInstData();
324 auto newTy = xegpu::TensorDescType::get(
325 type.getContext(), subShape, elemTy, type.getEncoding(), layout);
326 result.append(count, newTy);
334 [&](Operation *op) -> LogicalResult {
return success(needsUnroll(op)); });
338 options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
339 bool returnSingleType =
false) {
340 Type elemTy = type.getElementType();
343 if (
auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
345 Attribute encoding = tdescTy.getEncoding();
348 if (tdescTy.isScattered()) {
349 int64_t chunkSize = tdescTy.getChunkSizeAsInt();
352 int64_t blockedChunkSize = chunkSize;
353 auto instData = tdescTy.getLayoutAttr().getInstData();
354 if (!instData.empty())
355 blockedChunkSize = instData.asArrayRef().back();
358 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
359 ctx, tdescTy.getMemorySpace(), blockedChunkSize);
360 encoding = newEncoding;
365 xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
366 tdescTy.getLayoutAttr().dropInstData());
368 newTy = VectorType::get(tileShape, elemTy);
371 if (returnSingleType)
372 return SmallVector<Type>{newTy};
373 std::optional<SmallVector<int64_t>> ratio =
375 assert(ratio &&
"The shape of the type must be a multiple of tileShape.");
380 patterns.add<ConvertLayoutOpPattern>(ctx);
382 vector::UnrollVectorOptions vectorOptions;
383 vectorOptions.setNativeShapeFn(
options.nativeShape);
386 vector::populateVectorUnrollPatterns(
patterns, vectorOptions);
390 op->
walk([](Operation *op) {
401 if (
auto layout = op->
getAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
403 if (!isa<LoopLikeOpInterface>(op))
409 if (
auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
410 resolveUnrealizedConversionCastOp(castOp);
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
static llvm::ManagedStatic< PassManagerOptions > options
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
OpOperand & getOpOperand(unsigned idx)
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
ArrayRef< T > asArrayRef() const
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_strided_slice.
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
bool recoverTemporaryLayouts(Operation *rootOp)
Attach layout attributes to all vector-type operands of operations within the given operation's neste...
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_strided_slice.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...