24#include "llvm/Support/FormatVariadic.h"
33 for (
const auto &vals : values)
34 llvm::append_range(
result, vals);
40 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
43 if (!layout || !layout.isForSubgroup())
48 auto tdescShape = tdescTy.getShape();
49 auto elementType = tdescTy.getElementType();
54 int64_t sgSize = llvm::product_of(laneLayout);
57 auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
59 auto chunkSize = scatterAttr.getChunkSize().getInt();
62 assert(tdescShape[0] == laneLayout[0] &&
63 "tensor descriptor shape is not distributable");
64 return VectorType::get({chunkSize}, elementType);
70 for (
auto [tdescDim, laneDim, laneDataDim] :
71 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
72 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
73 "tensor descriptor shape is not distributable");
74 tensorSize *= tdescDim;
77 tensorSize *= tdescTy.getArrayLength();
79 return VectorType::get({tensorSize / sgSize}, elementType);
84 xegpu::LayoutAttr layout) {
85 int64_t rank = originalType.getRank();
87 if (rank < 1 || rank > 3)
94 arrayLength =
shape[0];
97 auto helperTdescTy = xegpu::TensorDescType::get(
98 shape, originalType.getElementType(), arrayLength,
100 xegpu::MemorySpace::Global, layout);
105 const StringRef prefix(
"layout_operand_");
106 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
107 return llvm::formatv(
"{0}{1}", prefix, idx).str();
111 const StringRef prefix =
"layout_result_";
112 return llvm::formatv(
"{0}{1}", prefix,
result.getResultNumber()).str();
120 dyn_cast_if_present<xegpu::TensorDescType>(value.
getType()))
121 return tdescTy.getLayoutAttr();
123 if (
auto result = dyn_cast<OpResult>(value)) {
125 assert(defOp &&
"result must have a defining op");
128 if (
auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
129 return convertOp.getTargetLayoutAttr();
132 if (
auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
136 if (
auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(defOp))
137 return loadOp.getLayoutAttr();
140 if (
auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
141 return storeOp.getLayoutAttr();
143 if (defOp->
hasAttr(layoutName))
144 return defOp->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
148 if (
auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
149 return loadGatherOp.getLayoutAttr();
152 if (
auto arg = dyn_cast<BlockArgument>(value)) {
153 auto *parentOp = arg.getOwner()->getParentOp();
154 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
155 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
164xegpu::DistributeLayoutAttr
168 if (
auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(op))
169 return loadOp.getLayoutAttr();
171 if (
auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
172 return storeOp.getLayoutAttr();
176 return op->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
179 if (
auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
180 if (
auto layout = storeScatterOp.getLayoutAttr())
188xegpu::DistributeLayoutAttr
191 const std::string &name) {
192 xegpu::DistributeLayoutAttr candidate = layout;
194 if (
auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
195 if (
auto perm = loadOp.getLayoutAttr())
204xegpu::DistributeLayoutAttr
207 const std::string &name) {
208 xegpu::DistributeLayoutAttr candidate = layout;
209 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
211 if (
auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
213 if (
auto perm = storeOp.getLayoutAttr())
221template <
typename T,
typename>
223 const DistributeLayoutAttr layout,
224 bool respectPermLayout) {
225 Operation *owner = operandOrResult.getOwner();
231 DistributeLayoutAttr candidate = layout;
232 if (respectPermLayout)
236 owner->
setAttr(name, candidate);
242 const mlir::xegpu::DistributeLayoutAttr layout,
bool respectPermLayout);
247 const mlir::xegpu::DistributeLayoutAttr layout,
bool respectPermLayout);
252 if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
256 auto layout = getLayoutImpl(opr.get());
260 auto layout = getLayoutImpl(
result);
266template <
typename T,
typename>
268 Operation *owner = operandOrResult.getOwner();
294 auto vecTy = dyn_cast<VectorType>(value.
getType());
302 int64_t srcShapeRank = srcShape.size();
306 int64_t rankDiff = srcShapeRank - targetShapeRank;
307 std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
309 llvm::copy(
shape, adjustedTargetShape.begin() + rankDiff);
315 Value slice = vector::ExtractStridedSliceOp::create(
316 builder, loc, value, offsets, adjustedTargetShape, staticStrides);
319 if (srcShapeRank > targetShapeRank) {
320 auto targetTy = VectorType::get(
shape, vecTy.getElementType());
321 slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
332 VectorType inputTy = dyn_cast<VectorType>(values[0].
getType());
333 assert(llvm::all_of(values.
getTypes(),
334 [&](
Type type) { return type == inputTy; }) &&
335 "values must be of the same VectorType");
337 Type elemTy = inputTy.getElementType();
340 VectorType resultTy = VectorType::get(
shape, elemTy);
345 for (
auto [src, offsets] :
348 result = vector::InsertStridedSliceOp::create(builder, loc, src,
result,
349 offsets, staticStrides);
360 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
366 converter.addConversion([](
Type type) ->
Type {
return type; });
367 converter.addConversion([](VectorType type) ->
Type {
368 return RankedTensorType::get(type.getShape(), type.getElementType());
370 converter.addSourceMaterialization(materializeCast);
371 converter.addTargetMaterialization(materializeCast);
373 mlir::ConversionTarget
target(*context);
374 target.addLegalOp<UnrealizedConversionCastOp>();
385 op->
walk([](UnrealizedConversionCastOp castOp) {
386 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
389 Value input = castOp.getInputs()[0];
391 auto inputTy = dyn_cast<VectorType>(input.
getType());
392 auto resultTy = dyn_cast<RankedTensorType>(
result.getType());
395 if (!inputTy || !resultTy)
398 xegpu::DistributeLayoutAttr layout =
403 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
408 if (
auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
414 if (
auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
415 unsigned idx = use.getOperandNumber();
424 op->
walk([](scf::YieldOp yieldOp) {
427 unsigned idx = r.getResultNumber();
428 Type resultTy = r.getType();
429 Type yieldTy = yieldOp.getResults()[idx].getType();
430 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
443 class UnrealizedConversionCastOpPattern
444 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
445 using OpConversionPattern<
446 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
449 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
451 ConversionPatternRewriter &rewriter)
const override {
452 auto inputs = op.getOperands();
453 auto outputs = op.getOutputs();
455 if (inputs.size() != 1 || outputs.size() != 1)
458 auto inputTy = inputs[0].getType();
459 auto outputTy = outputs[0].getType();
461 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
462 rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
466 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
468 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
470 rewriter.replaceOp(op, newOp);
477 converter.addSourceMaterialization(materializeCast);
480 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
484 mlir::ConversionTarget
target(*context);
485 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
486 [](UnrealizedConversionCastOp op) {
487 auto isTensorTy = [](
Type type) {
488 return isa<RankedTensorType>(type);
494 patterns.insert<UnrealizedConversionCastOpPattern>(context);
507 auto targetAttrs = gpuModuleOp.getTargets();
509 for (
auto &attr : *targetAttrs) {
510 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
512 return xevmAttr.getChip().str();
524 assert(
lhs.size() ==
rhs.size() &&
"lhs and rhs must have the same size");
526 for (
auto [l, r] : llvm::zip_equal(
lhs,
rhs)) {
529 results.push_back(builder.
createOrFold<arith::AddIOp>(loc, lval, rval));
552 a = a.slice(a.size() -
b.size());
560 static_assert(std::is_integral<T>::value,
"T must be an integer type");
563 if (!candidateMultiples.empty())
565 SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
566 for (T candidate : candidates) {
567 for (T multiple : multiples) {
568 int value =
static_cast<int>(candidate * multiple);
569 if (value != 0 && dim % value == 0 && value > largest)
xegpu::DistributeLayoutAttr maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, const OpResult &result, mlir::Operation *owner, const std::string &name)
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
MutableArrayRef< OpOperand > getOpOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
MLIRContext * getContext()
Return the context this operation is associated with.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tileShape within a larger shape, using a row-major ordering.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with the appropriate legality configuration for the ops to get converted properly.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void setDistributeLayoutAttrs(Operation *op, function_ref< DistributeLayoutAttr(Value)> getLayoutImpl)
Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout, bool respectPermLayout=false)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attributes.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
llvm::function_ref< Fn > function_ref