MLIR  19.0.0git
XeGPUOps.cpp
Go to the documentation of this file.
1 //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"
15 
16 #define DEBUG_TYPE "xegpu"
17 
18 namespace mlir {
19 namespace xegpu {
20 
22  SmallVector<int64_t> &shape) {
23  SmallVector<int64_t> old = shape;
24  for (size_t i = 0; i < trans.size(); i++)
25  shape[i] = old[trans[i]];
26 }
27 
28 template <typename T>
29 static std::string makeString(T array, bool breakline = false) {
30  std::string buf;
31  buf.clear();
32  llvm::raw_string_ostream os(buf);
33  os << "[";
34  for (size_t i = 1; i < array.size(); i++) {
35  os << array[i - 1] << ", ";
36  if (breakline)
37  os << "\n\t\t";
38  }
39  os << array.back() << "]";
40  os.flush();
41  return buf;
42 }
43 
46  if (auto ty = llvm::dyn_cast<ShapedType>(type))
47  shape = SmallVector<int64_t>(ty.getShape());
48  else
49  shape.push_back(1);
50  return shape;
51 }
52 
53 static int64_t getRankOf(Value val) {
54  auto type = val.getType();
55  if (auto ty = llvm::dyn_cast<ShapedType>(type))
56  return ty.getRank();
57  return 0;
58 }
59 
60 static bool isReadHintOrNone(const CachePolicyAttr &attr) {
61  if (!attr)
62  return true;
63  auto kind = attr.getValue();
64  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
65  kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
66 }
67 
68 static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
69  if (!attr)
70  return true;
71  auto kind = attr.getValue();
72  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
73  kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // XeGPU_CreateNdDescOp
78 //===----------------------------------------------------------------------===//
79 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
80  Type tdesc, TypedValue<MemRefType> source,
82  [[maybe_unused]] auto ty = source.getType();
83  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
84 
85  llvm::SmallVector<int64_t> staticOffsets;
86  llvm::SmallVector<Value> dynamicOffsets;
87  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
88 
89  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
90  ValueRange({}) /* empty dynamic shape */,
91  ValueRange({}) /* empty dynamic strides */,
92  staticOffsets /* const offsets */, {} /* empty const shape*/,
93  {} /* empty const strides*/);
94 }
95 
96 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
97  Type tdesc, TypedValue<IntegerType> source,
101  assert(shape.size() && offsets.size() && strides.size() &&
102  shape.size() == strides.size() && shape.size() == offsets.size());
103 
104  llvm::SmallVector<int64_t> staticOffsets;
105  llvm::SmallVector<int64_t> staticShape;
106  llvm::SmallVector<int64_t> staticStrides;
107  llvm::SmallVector<Value> dynamicOffsets;
108  llvm::SmallVector<Value> dynamicShape;
109  llvm::SmallVector<Value> dynamicStrides;
110 
111  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
112  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
113  dispatchIndexOpFoldResults(strides, dynamicStrides, staticOffsets);
114 
115  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
116  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
117  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
118 
119  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
120  dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
121 }
122 
123 LogicalResult CreateNdDescOp::verify() {
124  auto rank = (int64_t)getMixedOffsets().size();
125  bool invalidRank = (rank != 2);
126  bool invalidElemTy = false;
127 
128  // check source type matches the rank if it is a memref.
129  // It also should have the same ElementType as TensorDesc.
130  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
131  if (memrefTy) {
132  invalidRank |= (memrefTy.getRank() != rank);
133  invalidElemTy |= memrefTy.getElementType() != getElementType();
134  }
135 
136  // check result type matches the rank
137  invalidRank = (getType().getRank() != rank);
138 
139  // mismatches among shape, strides, and offsets are
140  // already handeled by OffsetSizeAndStrideOpInterface.
141  // So they are not check here.
142  if (invalidRank)
143  return emitOpError(
144  "Expecting the rank of shape, strides, offsets, "
145  "source memref type (if source is a memref) and TensorDesc "
146  "should match with each other. They currenlty are 2D.");
147 
148  if (invalidElemTy)
149  return emitOpError("TensorDesc should have the same element "
150  "type with the source if it is a memref.\n");
151 
152  if (getType().getScattered())
153  return emitOpError("Expects a non-scattered TensorDesc.\n");
154 
155  return success();
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // XeGPU_PrefetchNdOp
160 //===----------------------------------------------------------------------===//
161 LogicalResult PrefetchNdOp::verify() {
162  auto tdescTy = getTensorDescType();
163  if (tdescTy.getScattered())
164  return emitOpError("Expects a non-scattered TensorDesc.\n");
165 
166  if (!isReadHintOrNone(getL1HintAttr()))
167  return emitOpError("invlid l1_hint: ") << getL1HintAttr();
168 
169  if (!isReadHintOrNone(getL2HintAttr()))
170  return emitOpError("invlid l2_hint: ") << getL2HintAttr();
171 
172  if (!isReadHintOrNone(getL3HintAttr()))
173  return emitOpError("invlid l3_hint: ") << getL3HintAttr();
174 
175  return success();
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // XeGPU_LoadNdOp
180 //===----------------------------------------------------------------------===//
181 LogicalResult LoadNdOp::verify() {
182  auto tdescTy = getTensorDescType();
183  auto valueTy = getType();
184 
185  if (tdescTy.getRank() != 2)
186  return emitOpError("Expecting a 2D TensorDesc.\n");
187 
188  if (tdescTy.getScattered())
189  return emitOpError("Expects a non-scattered TensorDesc.\n");
190 
191  if (!valueTy)
192  return emitOpError("Invalid result, it should be a VectorType.\n");
193 
194  if (!isReadHintOrNone(getL1HintAttr()))
195  return emitOpError("invlid l1_hint: ") << getL1HintAttr();
196 
197  if (!isReadHintOrNone(getL2HintAttr()))
198  return emitOpError("invlid l2_hint: ") << getL2HintAttr();
199 
200  if (!isReadHintOrNone(getL3HintAttr()))
201  return emitOpError("invlid l3_hint: ") << getL3HintAttr();
202 
203  auto array_len = tdescTy.getArrayLength();
204  auto tdescShape = getShapeOf(tdescTy);
205  auto valueShape = getShapeOf(valueTy);
206 
207  if (getTranspose()) {
208  auto trans = getTranspose().value();
209  if (tdescShape.size() >= trans.size())
210  transpose(trans, tdescShape);
211  else
212  emitWarning("Invalid transpose attr. It is ignored.");
213  }
214 
215  if (getVnniAxis()) {
216  auto axis = getVnniAxis().value();
217  auto vnni_factor = valueShape.back();
218  tdescShape[axis] /= vnni_factor;
219  tdescShape.push_back(vnni_factor);
220  }
221 
222  if (array_len > 1) {
223  auto it = tdescShape.begin();
224  tdescShape.insert(it, array_len);
225  }
226 
227  if (tdescShape != valueShape)
228  return emitOpError() << "Result shape doesn't match TensorDesc shape."
229  << "The expected shape is " << makeString(tdescShape)
230  << ". But the given shape is "
231  << makeString(valueShape) << ".\n";
232  return success();
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // XeGPU_StoreNdOp
237 //===----------------------------------------------------------------------===//
238 LogicalResult StoreNdOp::verify() {
239  auto dstTy = getTensorDescType(); // Tile
240  auto valTy = getValueType(); // Vector
241 
242  if (dstTy.getRank() != 2)
243  return emitOpError("Expecting a 2D TensorDesc.\n");
244 
245  if (dstTy.getScattered())
246  return emitOpError("Expects a non-scattered TensorDesc.\n");
247 
248  if (!valTy)
249  return emitOpError("Exepcting a VectorType result.\n");
250 
251  if (!isWriteHintOrNone(getL1HintAttr()))
252  return emitOpError("invlid l1_hint: ") << getL1HintAttr();
253 
254  if (!isWriteHintOrNone(getL2HintAttr()))
255  return emitOpError("invlid l2_hint: ") << getL2HintAttr();
256 
257  if (!isWriteHintOrNone(getL3HintAttr()))
258  return emitOpError("invlid l3_hint: ") << getL3HintAttr();
259 
260  return success();
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // XeGPU_UpdateNDOffsetOp
265 //===----------------------------------------------------------------------===//
266 LogicalResult UpdateNdOffsetOp::verify() {
267  auto ty = getTensorDescType();
268  if (ty.getScattered())
269  return emitOpError("Expects a non-scattered TensorDesc.\n");
270 
271  // number of offsets specified must match the rank of the tensor descriptor
272  if (ty.getRank() != (int64_t)getNumOffsets()) {
273  return emitOpError("Invalid number of offsets.");
274  }
275  return success();
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // XeGPU_CreateDescOp
280 //===----------------------------------------------------------------------===//
281 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
282  TensorDescType TensorDesc, Value source,
284  uint32_t chunk_size) {
285  llvm::SmallVector<int64_t> staticOffsets;
286  llvm::SmallVector<Value> dynamicOffsets;
287  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
288  build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
289  chunk_size);
290 }
291 
292 LogicalResult CreateDescOp::verify() {
293  auto tdescTy = getTensorDescType();
294  auto chunkSize = getChunkSize();
295 
296  if (getRankOf(getSource()) > 1)
297  return emitOpError(
298  "Expecting the source is a 1D memref or pointer (uint64_t).");
299 
300  if (!tdescTy.getScattered())
301  return emitOpError("Expects a scattered TensorDesc.\n");
302 
303  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
304  if (chunkSize != 1)
305  shape.push_back(chunkSize);
306 
307  auto tdescShape = getShapeOf(tdescTy);
308  if (shape != tdescShape)
309  return emitOpError("Incorrect TensorDesc shape. ")
310  << "Expected is " << makeString(shape) << "\n";
311 
312  return success();
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // XeGPU_PrefetchOp
317 //===----------------------------------------------------------------------===//
318 LogicalResult PrefetchOp::verify() {
319  auto tdescTy = getTensorDescType();
320  if (!tdescTy.getScattered())
321  return emitOpError("Expects a scattered TensorDesc.\n");
322 
323  if (!isReadHintOrNone(getL1HintAttr()))
324  return emitOpError("invlid l1_hint: ") << getL1HintAttr();
325 
326  if (!isReadHintOrNone(getL2HintAttr()))
327  return emitOpError("invlid l2_hint: ") << getL2HintAttr();
328 
329  if (!isReadHintOrNone(getL3HintAttr()))
330  return emitOpError("invlid l3_hint: ") << getL3HintAttr();
331 
332  return success();
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // XeGPU_LoadGatherOp
337 //===----------------------------------------------------------------------===//
338 LogicalResult LoadGatherOp::verify() {
339  auto tdescTy = getTensorDescType();
340  auto maskTy = getMaskType();
341  auto valueTy = getValueType();
342 
343  if (!tdescTy.getScattered())
344  return emitOpError("Expects a scattered TensorDesc.\n");
345 
346  if (!isReadHintOrNone(getL1HintAttr()))
347  return emitOpError("invlid l1_hint: ") << getL1HintAttr();
348 
349  if (!isReadHintOrNone(getL2HintAttr()))
350  return emitOpError("invlid l2_hint: ") << getL2HintAttr();
351 
352  if (!isReadHintOrNone(getL3HintAttr()))
353  return emitOpError("invlid l3_hint: ") << getL3HintAttr();
354 
355  auto tdescElemTy = tdescTy.getElementType();
356  auto valueElemTy = getElementType();
357  if (tdescElemTy != valueElemTy)
358  return emitOpError(
359  "Value should have the same element type as TensorDesc.");
360 
361  auto maskShape = getShapeOf(maskTy);
362  auto valueShape = getShapeOf(valueTy);
363  auto tdescShape = getShapeOf(tdescTy);
364 
365  if (tdescShape[0] != maskShape[0])
366  return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
367 
368  if (getTransposeAttr()) {
369  auto trans = getTranspose().value();
370  if (tdescShape.size() < trans.size())
371  emitWarning("Invalid transpose attr. It is ignored.");
372  else
373  transpose(trans, tdescShape);
374  }
375 
376  if (valueShape != tdescShape)
377  return emitOpError("Unexpected result shape")
378  << "(Expected shape: " << makeString(tdescShape)
379  << ", Given shape: " << makeString(valueShape) << ").\n";
380 
381  return success();
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // XeGPU_StoreScatterOp
386 //===----------------------------------------------------------------------===//
387 LogicalResult StoreScatterOp::verify() {
388  auto tdescTy = getTensorDescType();
389  if (!tdescTy.getScattered())
390  return emitOpError("Expects a scattered TensorDesc.\n");
391 
392  if (!isWriteHintOrNone(getL1HintAttr()))
393  return emitOpError("invlid l1_hint: ") << getL1HintAttr();
394 
395  if (!isWriteHintOrNone(getL2HintAttr()))
396  return emitOpError("invlid l2_hint: ") << getL2HintAttr();
397 
398  if (!isWriteHintOrNone(getL3HintAttr()))
399  return emitOpError("invlid l3_hint: ") << getL3HintAttr();
400 
401  auto maskTy = getMaskType();
402  auto maskShape = getShapeOf(maskTy);
403  auto tdescShape = getShapeOf(tdescTy);
404  if (tdescShape[0] != maskShape[0])
405  return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
406 
407  return success();
408 }
409 //===----------------------------------------------------------------------===//
410 // XeGPU_DpasOp
411 //===----------------------------------------------------------------------===//
412 LogicalResult DpasOp::verify() {
413  int64_t lhsRank = getLhsType().getRank();
414  int64_t rhsRank = getRhsType().getRank();
415 
416  if (lhsRank != rhsRank || lhsRank != 3)
417  return emitOpError(
418  "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
419 
420  if (getAcc() && getAccType() != getResultType())
421  return emitOpError("Accumulator and Result for dpas op should have the "
422  "same type (both shape and element type).");
423 
424  auto lhsShape = getLhsType().getShape();
425  auto rhsShape = getRhsType().getShape();
426  if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
427  return emitOpError("K-dimension or vnni-factor mismatch.");
428 
429  return success();
430 }
431 
432 } // namespace xegpu
433 } // namespace mlir
434 
435 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
436 #define GET_OP_CLASSES
437 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:216
This class helps build Operations.
Definition: Builders.h:209
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static std::string makeString(T array, bool breakline=false)
Definition: XeGPUOps.cpp:29
static int64_t getRankOf(Value val)
Definition: XeGPUOps.cpp:53
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
Definition: XeGPUOps.cpp:68
static bool isReadHintOrNone(const CachePolicyAttr &attr)
Definition: XeGPUOps.cpp:60
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
static SmallVector< int64_t > getShapeOf(Type type)
Definition: XeGPUOps.cpp:44
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
This represents an operation in an abstracted form, suitable for use with the builder APIs.