25 #include "llvm/Support/FormatVariadic.h"
34 for (
const auto &vals : values)
35 llvm::append_range(result, vals);
41 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
44 if (!layout || !layout.isForSubgroup())
49 auto tdescShape = tdescTy.getShape();
50 auto elementType = tdescTy.getElementType();
55 auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
56 std::multiplies<int64_t>());
59 auto scatterAttr = tdescTy.getEncodingOfType<ScatterTensorDescAttr>();
61 auto chunkSize = scatterAttr.getChunkSize().getInt();
64 assert(tdescShape[0] == laneLayout[0] &&
65 "tensor descriptor shape is not distributable");
71 int64_t tensorSize = 1;
72 for (
auto [tdescDim, laneDim, laneDataDim] :
73 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
74 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
75 "tensor descriptor shape is not distributable");
76 tensorSize *= tdescDim;
79 tensorSize *= tdescTy.getArrayLength();
86 xegpu::LayoutAttr layout) {
87 int64_t rank = originalType.getRank();
89 if (rank < 1 || rank > 3)
96 arrayLength = shape[0];
97 shape = shape.drop_front();
100 shape, originalType.getElementType(), arrayLength,
102 xegpu::MemorySpace::Global, layout);
107 const StringRef prefix(
"layout_operand_");
108 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
109 return llvm::formatv(
"{0}{1}", prefix, idx).str();
113 const StringRef prefix =
"layout_result_";
114 return llvm::formatv(
"{0}{1}", prefix, result.
getResultNumber()).str();
122 dyn_cast_if_present<xegpu::TensorDescType>(value.
getType()))
123 return tdescTy.getLayoutAttr();
125 if (
auto result = dyn_cast<OpResult>(value)) {
126 Operation *defOp = result.getDefiningOp();
127 assert(defOp &&
"result must have a defining op");
130 if (
auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
131 return convertOp.getTargetLayoutAttr();
134 if (
auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
138 if (defOp->
hasAttr(layoutName))
139 return defOp->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
142 if (
auto arg = dyn_cast<BlockArgument>(value)) {
143 auto parentOp = arg.getOwner()->getParentOp();
144 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
145 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
154 xegpu::DistributeLayoutAttr
159 return op->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
163 template <
typename T,
typename>
165 const DistributeLayoutAttr layout) {
166 Operation *owner = operandOrResult.getOwner();
168 if (layout && !owner->
hasAttrOfType<DistributeLayoutAttr>(name))
173 template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
175 const mlir::xegpu::DistributeLayoutAttr layout);
178 template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
180 const mlir::xegpu::DistributeLayoutAttr layout);
186 auto layout = getLayoutImpl(opr.get());
187 setDistributeLayoutAttr(opr, layout);
190 auto layout = getLayoutImpl(result);
191 setDistributeLayoutAttr(result, layout);
196 template <
typename T,
typename>
198 Operation *owner = operandOrResult.getOwner();
206 xegpu::removeLayoutAttr<mlir::OpResult>(
const mlir::OpResult &result);
210 xegpu::removeLayoutAttr<mlir::OpOperand>(
const mlir::OpOperand &operand);
224 auto vecTy = dyn_cast<VectorType>(value.
getType());
235 result.push_back(vector::ExtractStridedSliceOp::create(
236 builder, loc, value, offsets, shape, staticStrides));
245 VectorType inputTy = dyn_cast<VectorType>(values[0].
getType());
246 assert(llvm::all_of(values.
getTypes(),
247 [&](
Type type) { return type == inputTy; }) &&
248 "values must be of the same VectorType");
250 Type elemTy = inputTy.getElementType();
255 Value result = arith::ConstantOp::create(
258 for (
auto [src, offsets] :
261 result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
262 offsets, staticStrides);
273 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
287 target.
addLegalOp<UnrealizedConversionCastOp>();
298 op->
walk([](UnrealizedConversionCastOp castOp) {
299 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
302 Value input = castOp.getInputs()[0];
303 Value result = castOp.getResults()[0];
304 auto inputTy = dyn_cast<VectorType>(input.
getType());
305 auto resultTy = dyn_cast<RankedTensorType>(result.
getType());
308 if (!inputTy || !resultTy)
311 xegpu::DistributeLayoutAttr layout =
316 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
321 if (
auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
327 if (
auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
328 unsigned idx = use.getOperandNumber();
337 op->
walk([](scf::YieldOp yieldOp) {
340 unsigned idx = r.getResultNumber();
341 Type resultTy = r.getType();
342 Type yieldTy = yieldOp.getResults()[idx].getType();
343 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
356 class UnrealizedConversionCastOpPattern
362 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
363 OneToNOpAdaptor adaptor,
365 auto inputs = op.getOperands();
366 auto outputs = op.getOutputs();
368 if (inputs.size() != 1 || outputs.size() != 1)
371 auto inputTy = inputs[0].getType();
372 auto outputTy = outputs[0].getType();
374 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
379 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
381 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
393 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
399 [](UnrealizedConversionCastOp op) {
400 auto isTensorTy = [](
Type type) {
401 return isa<RankedTensorType>(type);
407 patterns.insert<UnrealizedConversionCastOpPattern>(context);
420 auto targetAttrs = gpuModuleOp.getTargets();
422 for (
auto &attr : *targetAttrs) {
423 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
425 return xevmAttr.getChip().str();
450 a = a.slice(a.size() - b.size());
451 for (
auto [l, r] : llvm::zip(a, b)) {
454 results.push_back(builder.
createOrFold<index::AddOp>(loc, lval, rval));
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
result_range getOpResults()
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_strided_slice.
void setDistributeLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout)
Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictio...
void setDistributeLayoutAttrs(Operation *op, function_ref< DistributeLayoutAttr(Value)> getLayoutImpl)
Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
void removeLayoutAttr(const T &operandOrResult)
Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structural type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_strided_slice.
void removeLayoutAttrs(Operation *op)
Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given operation if they exist...
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRanges into a single SmallVector<Value>.
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If the tensor descriptor has a layout attribute, it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.