//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>

using std::optional;

namespace mlir {
namespace xegpu {

void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}

// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::LayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");

  // Checks whether the given shape can be evenly distributed using the
  // specified layout and data attributes. If successful, it returns the work
  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
  // size per compute unit is calculated as follows:
  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
  //   - If `data` is not null: newShape[i] = data[i]
  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
  // share the data.
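  // A hypothetical worked example (values are illustrative, not from the
  // code): for shape = [32, 32] with layout = [4, 4] and data = [16, 16],
  // the layout step yields newShape = [8, 8]; since [8, 8] is not a multiple
  // of [16, 16], the plain ratio fails, but with `rr` enabled the reverse
  // ratio [16, 16] / [8, 8] = [2, 2] succeeds, so distribution is accepted
  // and data = [16, 16] is returned for the next phase.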
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
                           bool rr = true) -> optional<SmallVector<int64_t>> {
    llvm::SmallVector<int64_t> newShape(shape);
    if (layout) {
      auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
      if (vec.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, vec);
      if (!ratio.has_value())
        return std::nullopt;
      newShape = ratio.value();
    }

    if (data) {
      auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
      if (vec.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, vec);
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(vec, newShape);
      if (!ratio.has_value())
        return std::nullopt;

      // If `data` is not null, we always return it for the next phase.
      newShape = vec;
    }
    return newShape;
  };

  // Check the sg_layout and sg_data attributes.
  auto maybeSgShape =
      tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  // Check inst_data. It has neither a layout to divide by nor any need for
  // round-robin distribution.
  auto maybeInstShape =
      tryDistribute(sgShape, nullptr, attr.getInstData(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  // Check lane_layout and lane_data; round-robin is not allowed here either.
  auto maybeLaneShape =
      tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
  return maybeLaneShape.has_value();
}
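
// A minimal usage sketch (hypothetical values, not from this file): given a
// LayoutAttr with sg_layout = [4, 4] and sg_data = [16, 16],
//   XeGPUDialect::isEvenlyDistributable({64, 64}, layout)
// returns true, since each dimension satisfies 64 = 4 * 16.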

//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
                                             xegpu::MemorySpace memory_space,
                                             int array_length,
                                             bool boundary_check) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto lengthAttr =
      IntegerAttr::get(IntegerType::get(context, 64), array_length);
  auto boundaryAttr = BoolAttr::get(context, boundary_check);
  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}
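
// For reference, a hypothetical textual form of this attribute (values are
// illustrative):
//   #xegpu.block_tdesc_attr<memory_space = global, array_length = 2,
//                           boundary_check = true>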

//===----------------------------------------------------------------------===//
// XeGPU_ScatterTensorDescAttr
//===----------------------------------------------------------------------===//
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
                           xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}

LogicalResult ScatterTensorDescAttr::verify(
    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                              16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitError() << "invalid chunk size";

  return success();
}
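
// A hypothetical example (illustrative syntax): a scattered descriptor such
// as #xegpu.scatter_tdesc_attr<memory_space = global, chunk_size = 8> passes
// this check, while chunk_size = 5 would be rejected as invalid.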

//===----------------------------------------------------------------------===//
// XeGPU_LayoutAttr
//===----------------------------------------------------------------------===//
LogicalResult
LayoutAttr::verify(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // A valid layout must include at least one of sg_layout, inst_data, or
  // lane_layout: sg_layout is essential for the Workgroup layout, while
  // lane_layout is required for the Subgroup layout.
  if (!sg_layout && !inst_data && !lane_layout) {
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";
  }

  // sg_layout, inst_data, and lane_layout must have the same rank when they
  // are present.

  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError()
           << "expected inst_data and lane_layout to have the same rank";
  }

  // sg_data is optional for Workgroup layout, but its presence requires
  // sg_layout.
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional for Subgroup layout, but its presence requires
  // lane_layout.
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
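
// A hypothetical well-formed layout (illustrative textual form):
//   #xegpu.layout<sg_layout = [4, 4], sg_data = [16, 16],
//                 lane_layout = [2, 8], lane_data = [1, 2]>
// Every array present has rank 2, so all of the rank checks above pass.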

//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//

mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'.
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse optional attributes.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'.
  if (parser.parseGreater())
    return {};

  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); },
      parser.getContext(), shape, elementType,
      encoding.value_or(mlir::Attribute()), layout.value_or(mlir::Attribute()));
}
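
// Illustrative textual forms this parser accepts (values are hypothetical):
//   !xegpu.tensor_desc<8x16xf16>
//   !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>,
//                      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>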

void TensorDescType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";

  auto shape = getShape();
  for (int64_t dim : shape) {
    if (mlir::ShapedType::isDynamic(dim))
      printer << '?';
    else
      printer << dim;
    printer << 'x';
  }

  printer << getElementType();

  if (auto encoding = getEncoding())
    printer << ", " << encoding;

  if (auto layout = getLayout())
    printer << ", " << layout;

  printer << ">";
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int array_length,
                                   bool boundary_check,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                       boundary_check);
  return Base::get(context, shape, elementType, attr, layout);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int chunk_size,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
  return Base::get(context, shape, elementType, attr, layout);
}
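
// A minimal usage sketch (hypothetical names and values): building a block
// descriptor type from C++ might look like
//   auto ty = TensorDescType::get({8, 16}, f16Ty, /*array_length=*/1,
//                                 /*boundary_check=*/true,
//                                 MemorySpace::Global,
//                                 /*layout=*/mlir::Attribute());
// where `f16Ty` is a previously constructed Float16Type. The scatter
// overload takes a chunk_size instead of array_length/boundary_check.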

LogicalResult TensorDescType::verify(
    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
    mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();
  // Low-precision types are packed in 32-bit units.
  int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
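  // For example, f16 is 16 bits wide, so packingFactor == 32 / 16 == 2;
  // for i8 it is 4, and for 32-bit types it is 1.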
  if (rank != 1 && rank != 2)
    return emitError() << "expected 1D or 2D tensor";

  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    // Expected tensor ranks for scattered data:
    //   - 1D tensor for fully non-contiguous elements (chunk size == 1)
    //   - 2D tensor for scattered blocks (chunk size > 1)
    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";
    if (rank == 2 && chunkSize < 2)
      return emitError() << "expected chunk blocks for 2D tensor";
    // If chunk size > 1, the second dimension of the tensor shape must be
    // equal to the chunk size, and it must be a multiple of the packing
    // factor.
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected tensor shape[1] to match chunk size";
      if (shape.back() % packingFactor != 0)
        return emitError()
               << "expected tensor shape[1] to be a multiple of packing factor "
               << packingFactor;
    }
  }

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank == 2 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is not supported for 2D block tensor";
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      // Validate subgroup mapping rules for scattered tensors.
      // A work item's slice of a tensor with shape [sg_size] or
      // [sg_size, chunk_size] is [1] or [1, 32/element_ty_bit_width],
      // respectively, and the mapping should reflect that, because each
      // work item accesses data at 32-bit granularity.
      if (rank > 1 && laneData[0] != 1)
        return emitError()
               << "cannot map over non-contiguous scattered row elements";
      if (laneData[rank - 1] != packingFactor)
        return emitError() << "work item data mapping must match the number of "
                              "contiguous elements";
    }

    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}
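
// A hypothetical walkthrough of the scatter rules above (illustrative
// syntax): for
//   !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
// rank == 2 and chunkSize == 8, shape.back() == 8 matches the chunk size,
// and with f16 the packing factor is 2, so 8 % 2 == 0 and the type verifies.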

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>