13 #include "llvm/ADT/TypeSwitch.h"
21 void XeGPUDialect::initialize() {
23 #define GET_TYPEDEF_LIST
24 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
28 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
31 #define GET_ATTRDEF_LIST
32 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
39 xegpu::LayoutAttr attr) {
40 assert(attr &&
"Layout attribute is missing.");
53 bool rr =
true) -> optional<SmallVector<int64_t>> {
56 auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
57 if (vec.size() != shape.size())
60 if (!ratio.has_value())
62 newShape = ratio.value();
66 auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
67 if (vec.size() != shape.size())
70 if (!ratio.has_value() && rr)
72 if (!ratio.has_value())
83 tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
86 auto sgShape = maybeSgShape.value();
90 tryDistribute(sgShape,
nullptr, attr.getInstData(),
false);
93 auto instShape = maybeInstShape.value();
97 tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(),
false);
98 return maybeLaneShape.has_value();
105 xegpu::MemorySpace memory_space,
107 bool boundary_check) {
112 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
118 ScatterTensorDescAttr
120 xegpu::MemorySpace memory_space,
int chunk_size) {
124 return Base::get(context, scopeAttr, chunkSizeAttr);
129 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
130 int64_t chunkSize = chunk_size.getInt();
131 SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
132 16, 32, 64, 128, 256};
133 if (!llvm::is_contained(supportedChunkSizes, chunkSize))
134 return emitError() <<
"invalid chunk size";
151 if (!sg_layout && !inst_data && !lane_layout) {
153 <<
"expected at least one of sg_layout, inst_data or lane_layout";
159 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
161 <<
"expected sg_layout and inst_data to have the same rank";
164 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
166 <<
"expected sg_layout and lane_layout to have the same rank";
169 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
171 <<
"expected inst_data and lane_layout to have the same rank";
178 return emitError() <<
"expected sg_layout being used with sg_data";
179 if (sg_data.size() != sg_layout.size())
181 <<
"expected sg_data and sg_layout to have the same rank";
188 return emitError() <<
"expected lane_layout being used with lane_data";
189 if (lane_data.size() != lane_layout.size())
191 <<
"expected lane_data and lane_layout to have the same rank";
195 if (!sg_layout && !lane_layout)
197 <<
"expected sg_layout/lane_layout being used with order";
199 if (sg_layout && order.size() != sg_layout.size())
201 <<
"expected order and sg_layout to have the same rank";
203 if (lane_layout && order.size() != lane_layout.size())
205 <<
"expected order and lane_layout to have the same rank";
218 mlir::FailureOr<mlir::Attribute> encoding;
219 mlir::FailureOr<mlir::Attribute> layout;
227 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
232 if (mlir::failed(parser.
parseType(elementType))) {
233 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
241 if (mlir::succeeded(res)) {
242 if (mlir::isa<LayoutAttr>(attr)) {
246 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
258 return TensorDescType::getChecked(
268 for (int64_t dim : shape) {
269 if (mlir::ShapedType::isDynamic(dim))
278 if (
auto encoding = getEncoding())
279 printer <<
", " << encoding;
281 if (
auto layout = getLayout())
282 printer <<
", " << layout;
290 MemorySpace memory_space,
295 return Base::get(context, shape, elementType, attr, layout);
300 MemorySpace memory_space,
304 return Base::get(context, shape, elementType, attr, layout);
311 size_t rank = shape.size();
314 if (rank != 1 && rank != 2)
315 return emitError() <<
"expected 1D or 2D tensor";
317 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
322 unsigned chunkSize = scatterAttr.getChunkSize().getInt();
323 if (rank == 1 && chunkSize != 1)
324 return emitError() <<
"expected non-contiguous elements for 1D tensor";
325 if (rank == 2 && chunkSize < 2)
326 return emitError() <<
"expected chunk blocks for 2D tensor";
330 if (shape.back() != chunkSize)
331 return emitError() <<
"expected tensor shape[1] to match chunk size";
332 if (shape.back() % packingFactor != 0)
334 <<
"expected tensor shape[1] to be a multiple of packing factor "
339 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
341 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
342 if (rank == 2 && memorySpaceAttr &&
343 memorySpaceAttr.getValue() == MemorySpace::SLM)
344 return emitError() <<
"SLM is not supported for 2D block tensor";
347 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
349 if (rank != (
size_t)layoutAttr.getRank())
350 return emitError() <<
"expected layout rank to match tensor rank";
352 auto laneData = layoutAttr.getLaneData();
353 if (scatterAttr && laneData) {
360 if (rank > 1 && laneData[0] != 1)
362 <<
"cannot map over non-contiguous scattered row elements";
363 if (laneData[rank - 1] != packingFactor)
364 return emitError() <<
"work item data mapping must match the number of "
365 "contiguous elements";
368 if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
369 std::string shapeStr;
370 llvm::raw_string_ostream stream(shapeStr);
371 llvm::interleaveComma(shape, stream);
372 return emitError() <<
"cannot distribute [" << shapeStr <<
"] using "
382 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
383 #define GET_ATTRDEF_CLASSES
384 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
385 #define GET_TYPEDEF_CLASSES
386 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity, to select an element or element type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.