25 #include "llvm/Support/FormatVariadic.h"
34 for (
const auto &vals : values)
35 llvm::append_range(result, vals);
41 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
44 if (!layout || !layout.isForSubgroup())
49 auto tdescShape = tdescTy.getShape();
50 auto elementType = tdescTy.getElementType();
55 int64_t sgSize = llvm::product_of(laneLayout);
58 auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
60 auto chunkSize = scatterAttr.getChunkSize().getInt();
63 assert(tdescShape[0] == laneLayout[0] &&
64 "tensor descriptor shape is not distributable");
70 int64_t tensorSize = 1;
71 for (
auto [tdescDim, laneDim, laneDataDim] :
72 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
73 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
74 "tensor descriptor shape is not distributable");
75 tensorSize *= tdescDim;
78 tensorSize *= tdescTy.getArrayLength();
85 xegpu::LayoutAttr layout) {
86 int64_t rank = originalType.getRank();
88 if (rank < 1 || rank > 3)
95 arrayLength = shape[0];
96 shape = shape.drop_front();
99 shape, originalType.getElementType(), arrayLength,
101 xegpu::MemorySpace::Global, layout);
106 const StringRef prefix(
"layout_operand_");
107 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
108 return llvm::formatv(
"{0}{1}", prefix, idx).str();
112 const StringRef prefix =
"layout_result_";
113 return llvm::formatv(
"{0}{1}", prefix, result.
getResultNumber()).str();
121 dyn_cast_if_present<xegpu::TensorDescType>(value.
getType()))
122 return tdescTy.getLayoutAttr();
124 if (
auto result = dyn_cast<OpResult>(value)) {
125 Operation *defOp = result.getDefiningOp();
126 assert(defOp &&
"result must have a defining op");
129 if (
auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
130 return convertOp.getTargetLayoutAttr();
133 if (
auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
137 if (
auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(defOp))
138 return loadOp.getLayoutAttr();
141 if (
auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
142 return storeOp.getLayoutAttr();
145 if (defOp->
hasAttr(layoutName))
146 return defOp->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
149 if (
auto arg = dyn_cast<BlockArgument>(value)) {
150 auto *parentOp = arg.getOwner()->getParentOp();
151 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
152 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
161 xegpu::DistributeLayoutAttr
165 if (
auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(op))
166 return loadOp.getLayoutAttr();
168 if (
auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
169 return storeOp.getLayoutAttr();
173 return op->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
177 template <
typename T,
typename>
179 const DistributeLayoutAttr layout) {
180 Operation *owner = operandOrResult.getOwner();
182 if (layout && !owner->
hasAttrOfType<DistributeLayoutAttr>(name))
187 template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
189 const mlir::xegpu::DistributeLayoutAttr layout);
192 template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
194 const mlir::xegpu::DistributeLayoutAttr layout);
199 if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
203 auto layout = getLayoutImpl(opr.get());
204 setDistributeLayoutAttr(opr, layout);
207 auto layout = getLayoutImpl(result);
208 setDistributeLayoutAttr(result, layout);
213 template <
typename T,
typename>
215 Operation *owner = operandOrResult.getOwner();
223 xegpu::removeLayoutAttr<mlir::OpResult>(
const mlir::OpResult &result);
227 xegpu::removeLayoutAttr<mlir::OpOperand>(
const mlir::OpOperand &operand);
241 auto vecTy = dyn_cast<VectorType>(value.
getType());
249 int64_t srcShapeRank = srcShape.size();
250 int64_t targetShapeRank = shape.size();
253 int64_t rankDiff = srcShapeRank - targetShapeRank;
254 std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
256 std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
262 Value slice = vector::ExtractStridedSliceOp::create(
263 builder, loc, value, offsets, adjustedTargetShape, staticStrides);
266 if (srcShapeRank > targetShapeRank) {
268 slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
270 result.push_back(slice);
279 VectorType inputTy = dyn_cast<VectorType>(values[0].
getType());
280 assert(llvm::all_of(values.
getTypes(),
281 [&](
Type type) { return type == inputTy; }) &&
282 "values must be of the same VectorType");
284 Type elemTy = inputTy.getElementType();
289 Value result = arith::ConstantOp::create(
292 for (
auto [src, offsets] :
295 result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
296 offsets, staticStrides);
307 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
321 target.
addLegalOp<UnrealizedConversionCastOp>();
332 op->
walk([](UnrealizedConversionCastOp castOp) {
333 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
336 Value input = castOp.getInputs()[0];
337 Value result = castOp.getResults()[0];
338 auto inputTy = dyn_cast<VectorType>(input.
getType());
339 auto resultTy = dyn_cast<RankedTensorType>(result.
getType());
342 if (!inputTy || !resultTy)
345 xegpu::DistributeLayoutAttr layout =
350 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
355 if (
auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
361 if (
auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
362 unsigned idx = use.getOperandNumber();
371 op->
walk([](scf::YieldOp yieldOp) {
374 unsigned idx = r.getResultNumber();
375 Type resultTy = r.getType();
376 Type yieldTy = yieldOp.getResults()[idx].getType();
377 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
390 class UnrealizedConversionCastOpPattern
396 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
397 OneToNOpAdaptor adaptor,
399 auto inputs = op.getOperands();
400 auto outputs = op.getOutputs();
402 if (inputs.size() != 1 || outputs.size() != 1)
405 auto inputTy = inputs[0].getType();
406 auto outputTy = outputs[0].getType();
408 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
413 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
415 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
427 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
433 [](UnrealizedConversionCastOp op) {
434 auto isTensorTy = [](
Type type) {
435 return isa<RankedTensorType>(type);
441 patterns.insert<UnrealizedConversionCastOpPattern>(context);
454 auto targetAttrs = gpuModuleOp.getTargets();
456 for (
auto &attr : *targetAttrs) {
457 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
459 return xevmAttr.getChip().str();
471 assert(lhs.size() == rhs.size() &&
"lhs and rhs must have the same size");
473 for (
auto [l, r] : llvm::zip_equal(lhs, rhs)) {
476 results.push_back(builder.
createOrFold<index::AddOp>(loc, lval, rval));
499 a = a.slice(a.size() - b.size());
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following form: std::optional<Value>(OpBuilder &, T, ValueRange, Location), where T is any subclass of Type.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attributes.
void setDistributeLayoutAttrs(Operation *op, function_ref< DistributeLayoutAttr(Value)> getLayoutImpl)
Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.