12 #include "llvm/ADT/TypeSwitch.h"
17 void XeGPUDialect::initialize() {
19 #define GET_TYPEDEF_LIST
20 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
24 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
27 #define GET_ATTRDEF_LIST
28 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
36 xegpu::MemorySpace memory_space,
38 bool boundary_check) {
43 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
51 xegpu::MemorySpace memory_space,
int chunk_size) {
55 return Base::get(context, scopeAttr, chunkSizeAttr);
60 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
61 int64_t chunkSize = chunk_size.getInt();
62 SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
63 16, 32, 64, 128, 256};
64 if (!llvm::is_contained(supportedChunkSizes, chunkSize))
65 return emitError() <<
"invalid chunk size";
74 template <
typename T,
unsigned N>
77 llvm::StringRef fieldName) {
80 "unexpected field name. Expected " + fieldName +
".");
89 auto elemParser = [&]() -> llvm::ParseResult {
92 result.push_back(elem);
97 elemParser, fieldName);
107 if (failed(parseIntArrayField(parser, wi_layout,
"wi_layout")))
113 if (failed(parseIntArrayField(parser, wi_data,
"wi_data")))
116 return SGMapAttr::getChecked(
124 printer <<
" = [" << getWiLayout() <<
"], ";
126 printer <<
" = [" << getWiData() <<
"]";
134 if (wi_layout.size() != 2)
135 return emitError() <<
"expected wi_layout of size 2";
136 if (wi_data.size() != 2)
137 return emitError() <<
"expected wi_data of size 2";
148 mlir::FailureOr<mlir::Attribute> encoding;
149 mlir::FailureOr<mlir::Attribute> sg_map;
157 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
162 if (mlir::failed(parser.
parseType(elementType))) {
163 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
171 if (mlir::succeeded(res)) {
172 if (mlir::isa<SGMapAttr>(attr)) {
176 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
188 return TensorDescType::getChecked(
198 for (int64_t dim : shape) {
199 if (mlir::ShapedType::isDynamic(dim))
208 if (
auto encoding = getEncoding())
209 printer <<
", " << encoding;
211 if (
auto sg_map = getSgMap())
212 printer <<
", " << sg_map;
220 MemorySpace memory_space,
225 return Base::get(context, shape, elementType, attr, sg_map);
230 MemorySpace memory_space,
234 return Base::get(context, shape, elementType, attr, sg_map);
241 size_t rank = shape.size();
244 if (rank != 1 && rank != 2)
245 return emitError() <<
"expected 1D or 2D tensor";
247 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
252 unsigned chunkSize = scatterAttr.getChunkSize().getInt();
253 if (rank == 1 && chunkSize != 1)
254 return emitError() <<
"expected non-contiguous elements for 1D tensor";
255 if (rank == 2 && chunkSize < 2)
256 return emitError() <<
"expected chunk blocks for 2D tensor";
260 if (shape.back() != chunkSize)
261 return emitError() <<
"expected tensor shape[1] to match chunk size";
262 if (shape.back() % packingFactor != 0)
264 <<
"expected tensor shape[1] to be a multiple of packing factor "
270 mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
271 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
272 if (rank == 2 && memorySpaceAttr &&
273 memorySpaceAttr.getValue() == MemorySpace::SLM)
274 return emitError() <<
"SLM is not supported for 2D block tensor";
277 if (
auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
278 ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
279 ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
282 if (wiLayout[0] != 1 || wiData[0] != 1)
284 <<
"outer layout distribution and data mapping must be 1 "
296 <<
"cannot map over non-contiguous scattered row elements";
297 if (wiData[1] != packingFactor)
298 return emitError() <<
"work item data mapping must match the number of "
299 "contiguous elements";
304 SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
306 tensorShape = {1, tensorShape.back()};
308 size_t dims = tensorShape.size();
309 for (
size_t i = 0; i < dims; ++i) {
310 uint32_t numElemPerWi = wiLayout[i] * wiData[i];
311 if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
312 return emitError() <<
"cannot distribute " << tensorShape[i] <<
" over "
313 << wiLayout[i] <<
" work items with " << wiData[i]
345 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
346 auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
351 SmallVector<int64_t> wiData(sgMap.getWiData());
352 SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
355 auto wiDataSize = 1, sgSize = 1;
356 for (
auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
357 wiDataSize *= wiDataDim;
362 auto scatterAttr = getEncodingAsScatterTensorDescAttr();
364 auto chunkSize = scatterAttr.getChunkSize().getInt();
367 assert(tdescShape[0] % (wiLayout[0]) == 0 &&
368 "tensor descriptor shape is not distributable");
378 if (tdescShape.size() == 1) {
379 assert((wiData[0] == 1 && wiLayout[0] == 1) &&
380 "wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
381 wiData = {wiData[1]};
382 wiLayout = {wiLayout[1]};
385 int64_t tensorSize = 1;
386 for (
auto [tdescDim, wiDim, wiDataDim] :
387 llvm::zip_equal(tdescShape, wiLayout, wiData)) {
388 assert((tdescDim % (wiDim * wiDataDim) == 0) &&
389 "tensor descriptor shape is not distributable");
390 tensorSize *= tdescDim;
393 tensorSize *= getArrayLength();
395 return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
402 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
403 #define GET_ATTRDEF_CLASSES
404 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
405 #define GET_TYPEDEF_CLASSES
406 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...