22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
32 for (
const auto &vals : values)
33 llvm::append_range(result, vals);
39 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
42 if (!layout || !layout.isSgLayout())
47 auto tdescShape = tdescTy.getShape();
48 auto elementType = tdescTy.getElementType();
53 auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
54 std::multiplies<int64_t>());
57 auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
59 auto chunkSize = scatterAttr.getChunkSize().getInt();
62 assert(tdescShape[0] == laneLayout[0] &&
63 "tensor descriptor shape is not distributable");
69 int64_t tensorSize = 1;
70 for (
auto [tdescDim, laneDim, laneDataDim] :
71 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
72 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
73 "tensor descriptor shape is not distributable");
74 tensorSize *= tdescDim;
77 tensorSize *= tdescTy.getArrayLength();
84 xegpu::LayoutAttr layout) {
85 int64_t rank = originalType.getRank();
87 if (rank < 1 || rank > 3)
94 arrayLength = shape[0];
95 shape = shape.drop_front();
98 shape, originalType.getElementType(), arrayLength,
100 xegpu::MemorySpace::Global, layout);
105 const StringRef prefix(
"layout_operand_");
106 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
107 return llvm::formatv(
"{0}{1}", prefix, idx).str();
111 const StringRef prefix =
"layout_result_";
112 return llvm::formatv(
"{0}{1}", prefix, result.
getResultNumber()).str();
120 dyn_cast_if_present<xegpu::TensorDescType>(value.
getType()))
121 return tdescTy.getLayoutAttr();
123 if (
auto result = dyn_cast<OpResult>(value)) {
124 Operation *defOp = result.getDefiningOp();
125 assert(defOp &&
"result must have a defining op");
128 if (
auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
132 if (defOp->
hasAttr(layoutName))
136 if (
auto arg = dyn_cast<BlockArgument>(value)) {
137 auto parentOp = arg.getOwner()->getParentOp();
138 if (
auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
139 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
155 template <
typename T,
typename>
157 Operation *owner = operandOrResult.getOwner();
165 xegpu::setLayoutAttr<mlir::OpResult>(
const mlir::OpResult &result,
166 const mlir::xegpu::LayoutAttr layout);
171 const mlir::xegpu::LayoutAttr layout);
177 auto layout = getLayoutImpl(opr.get());
178 setLayoutAttr(opr, layout);
181 auto layout = getLayoutImpl(result);
182 setLayoutAttr(result, layout);
190 auto vecTy = dyn_cast<VectorType>(value.
getType());
201 result.push_back(builder.
create<vector::ExtractStridedSliceOp>(
202 loc, value, offsets, shape, staticStrides));
211 VectorType inputTy = dyn_cast<VectorType>(values[0].
getType());
212 assert(llvm::all_of(values.
getTypes(),
213 [&](
Type type) { return type == inputTy; }) &&
214 "values must be of the same VectorType");
216 Type elemTy = inputTy.getElementType();
224 for (
auto [src, offsets] :
227 result = builder.
create<vector::InsertStridedSliceOp>(
228 loc, src, result, offsets, staticStrides);
239 return builder.
create<UnrealizedConversionCastOp>(loc, type, inputs)
253 target.
addLegalOp<UnrealizedConversionCastOp>();
264 op->
walk([](UnrealizedConversionCastOp castOp) {
265 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
268 Value input = castOp.getInputs()[0];
269 Value result = castOp.getResults()[0];
270 auto inputTy = dyn_cast<VectorType>(input.
getType());
271 auto resultTy = dyn_cast<RankedTensorType>(result.
getType());
274 if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
281 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
286 if (
auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
292 if (
auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
293 unsigned idx = use.getOperandNumber();
302 op->
walk([](scf::YieldOp yieldOp) {
305 unsigned idx = r.getResultNumber();
306 Type resultTy = r.getType();
307 Type yieldTy = yieldOp.getResults()[idx].getType();
308 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
321 class UnrealizedConversionCastOpPattern
327 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
328 OneToNOpAdaptor adaptor,
330 auto inputs = op.getOperands();
331 auto outputs = op.getOutputs();
333 if (inputs.size() != 1 || outputs.size() != 1)
336 auto inputTy = inputs[0].getType();
337 auto outputTy = outputs[0].getType();
339 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
344 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
346 auto newOp = rewriter.
create<UnrealizedConversionCastOp>(
347 op.getLoc(), outputTy, values);
358 return builder.
create<UnrealizedConversionCastOp>(loc, type, inputs)
364 [](UnrealizedConversionCastOp op) {
365 auto isTensorTy = [](
Type type) {
366 return isa<RankedTensorType>(type);
372 patterns.insert<UnrealizedConversionCastOpPattern>(context);
static SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
convert ArrayRef<ValueRange> into SmallVector<Value>
This class represents an argument of a Block.
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
result_range getOpResults()
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
LayoutAttr getLayoutAttr(const Value value)
Retrieves the LayoutAttr associated with a given Value.
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach LayoutAttr.
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
void setLayoutAttrs(Operation *op, function_ref< LayoutAttr(Value)> getLayoutImpl)
Set the LayoutAttr for each OpOperand and OpResult of the given operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.