MLIR  21.0.0git
XeGPUOps.cpp
Go to the documentation of this file.
1 //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #include "llvm/Support/Debug.h"
16 
17 #define DEBUG_TYPE "xegpu"
18 
19 namespace mlir {
20 namespace xegpu {
21 
23  SmallVector<int64_t> &shape) {
24  SmallVector<int64_t> old = shape;
25  for (size_t i = 0; i < trans.size(); i++)
26  shape[i] = old[trans[i]];
27 }
28 
29 template <typename T>
30 static std::string makeString(T array, bool breakline = false) {
31  std::string buf;
32  buf.clear();
33  llvm::raw_string_ostream os(buf);
34  os << "[";
35  for (size_t i = 1; i < array.size(); i++) {
36  os << array[i - 1] << ", ";
37  if (breakline)
38  os << "\n\t\t";
39  }
40  os << array.back() << "]";
41  return buf;
42 }
43 
46  if (auto ty = llvm::dyn_cast<ShapedType>(type))
47  shape = SmallVector<int64_t>(ty.getShape());
48  else
49  shape.push_back(1);
50  return shape;
51 }
52 
53 static int64_t getRankOf(Value val) {
54  auto type = val.getType();
55  if (auto ty = llvm::dyn_cast<ShapedType>(type))
56  return ty.getRank();
57  return 0;
58 }
59 
60 static bool isReadHintOrNone(const CachePolicyAttr &attr) {
61  if (!attr)
62  return true;
63  auto kind = attr.getValue();
64  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
65  kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
66 }
67 
68 static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
69  if (!attr)
70  return true;
71  auto kind = attr.getValue();
72  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
73  kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
74 }
75 
76 // Validations for nd instruction arguments is successful if any of these are
77 // true:
78 // - tensor descriptor and the output vector shapes exactly match.
79 // - tensor descriptor has a sg_map attribute and the distributed vector shape
80 // matches the tensor descriptor shape when scaled using sg_map factors on
81 // each dimension.
82 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
83  ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
84  // Equal shapes with no distribution - no further verification needed.
85  if (descShape == valShape && !sgMap)
86  return true;
87 
88  // Unknown distribution - cannot perform operation on partial shape.
89  if (!sgMap)
90  return false;
91 
92  // Invalid rank or mixed rank usage.
93  size_t descRank = descShape.size();
94  if (descRank > 2 || valShape.size() != descRank)
95  return false;
96 
97  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
98  // Only take the distribution over the innermost dimension for validation.
99  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
100  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
101  if (descRank == 1)
102  mapLayout = {wiLayout.back()};
103 
104  for (const auto &[factor, dim, expected] :
105  llvm::zip_equal(mapLayout, valShape, descShape)) {
106  if (factor * dim != expected)
107  return false;
108  }
109 
110  return true;
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // XeGPU_CreateNdDescOp
115 //===----------------------------------------------------------------------===//
116 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
117  Type tdesc, TypedValue<MemRefType> source,
119  [[maybe_unused]] auto ty = source.getType();
120  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
121 
122  llvm::SmallVector<int64_t> staticOffsets;
123  llvm::SmallVector<Value> dynamicOffsets;
124  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
125 
126  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
127  ValueRange({}) /* empty dynamic shape */,
128  ValueRange({}) /* empty dynamic strides */,
129  staticOffsets /* const offsets */, {} /* empty const shape*/,
130  {} /* empty const strides*/);
131 }
132 
133 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
134  Type tdesc, TypedValue<MemRefType> source,
138  assert(shape.size() && offsets.size() && strides.size() &&
139  shape.size() == strides.size() && shape.size() == offsets.size());
140 
141  llvm::SmallVector<int64_t> staticOffsets;
142  llvm::SmallVector<int64_t> staticShape;
143  llvm::SmallVector<int64_t> staticStrides;
144  llvm::SmallVector<Value> dynamicOffsets;
145  llvm::SmallVector<Value> dynamicShape;
146  llvm::SmallVector<Value> dynamicStrides;
147 
148  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
149  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
150  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
151 
152  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
153  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
154  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
155 
156  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
157  dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
158 }
159 
160 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
161  Type tdesc, TypedValue<IntegerType> source,
165  assert(shape.size() && offsets.size() && strides.size() &&
166  shape.size() == strides.size() && shape.size() == offsets.size());
167 
168  llvm::SmallVector<int64_t> staticOffsets;
169  llvm::SmallVector<int64_t> staticShape;
170  llvm::SmallVector<int64_t> staticStrides;
171  llvm::SmallVector<Value> dynamicOffsets;
172  llvm::SmallVector<Value> dynamicShape;
173  llvm::SmallVector<Value> dynamicStrides;
174 
175  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
176  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
177  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
178 
179  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
180  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
181  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
182 
183  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
184  dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
185 }
186 
187 LogicalResult CreateNdDescOp::verify() {
188  auto rank = (int64_t)getMixedOffsets().size();
189  bool invalidRank = false;
190  bool invalidElemTy = false;
191 
192  // Memory space of created TensorDesc should match with the source.
193  // Both source and TensorDesc are considered for global memory by default,
194  // if the memory scope attr is not specified. If source is an integer,
195  // it is considered as ptr to global memory.
196  auto srcMemorySpace = getSourceMemorySpace();
197  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
198  if (srcMemorySpace != tdescMemorySpace)
199  return emitOpError("Memory space mismatch.")
200  << " Source: " << srcMemorySpace
201  << ", TensorDesc: " << tdescMemorySpace;
202 
203  // check source type matches the rank if it is a memref.
204  // It also should have the same ElementType as TensorDesc.
205  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
206  if (memrefTy) {
207  invalidRank |= (memrefTy.getRank() != rank);
208  invalidElemTy |= memrefTy.getElementType() != getElementType();
209  }
210 
211  // mismatches among shape, strides, and offsets are
212  // already handeled by OffsetSizeAndStrideOpInterface.
213  // So they are not check here.
214  if (invalidRank)
215  return emitOpError(
216  "Expecting the rank of shape, strides, offsets, and source (if source "
217  "is a memref) should match with each other.");
218 
219  // check result TensorDesc rank
220  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
221 
222  if (invalidRank)
223  return emitOpError(
224  "Expecting the TensorDesc rank is up to 2 and not greater than the "
225  "ranks of shape, strides, offsets or the memref source.");
226 
227  if (invalidElemTy)
228  return emitOpError("TensorDesc should have the same element "
229  "type with the source if it is a memref.\n");
230 
231  if (getType().isScattered())
232  return emitOpError("Expects a non-scattered TensorDesc.\n");
233 
234  return success();
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // XeGPU_PrefetchNdOp
239 //===----------------------------------------------------------------------===//
240 LogicalResult PrefetchNdOp::verify() {
241  auto tdescTy = getTensorDescType();
242  if (tdescTy.isScattered())
243  return emitOpError("Expects a non-scattered TensorDesc.\n");
244 
245  if (!isReadHintOrNone(getL1HintAttr()))
246  return emitOpError("invalid l1_hint: ") << getL1HintAttr();
247 
248  if (!isReadHintOrNone(getL2HintAttr()))
249  return emitOpError("invalid l2_hint: ") << getL2HintAttr();
250 
251  if (!isReadHintOrNone(getL3HintAttr()))
252  return emitOpError("invalid l3_hint: ") << getL3HintAttr();
253 
254  return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // XeGPU_LoadNdOp
259 //===----------------------------------------------------------------------===//
// Verifies that the result vector shape of a 2D block load matches the tensor
// descriptor shape after accounting for transpose, packing (VNNI), array
// length, and sg_map distribution.
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  // Load is a read access: every cache-level hint must be read-compatible.
  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = tdescTy.getArrayLength();
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  // A requested transpose is folded into the expected (descriptor) shape so
  // it can be compared against the result shape below.
  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid: each index must be an in-range
    // dimension of the descriptor.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  // Packed (VNNI) loads fold `vnni_factor` rows of dim-0 into a new innermost
  // dimension; only meaningful for 2D descriptors.
  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  // An array length > 1 prepends a leading dimension to the expected shape.
  if (array_len > 1) {
    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  }
  auto sgMap = tdescTy.getSGMapAttr();

  // Final check: exact match, or a valid sg_map distribution of the shape.
  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape."
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}
326 
327 //===----------------------------------------------------------------------===//
328 // XeGPU_StoreNdOp
329 //===----------------------------------------------------------------------===//
330 LogicalResult StoreNdOp::verify() {
331  auto dstTy = getTensorDescType(); // Tile
332  auto valTy = getValueType(); // Vector
333 
334  if (dstTy.getRank() > 2)
335  return emitOpError("Expecting a 1D/2D TensorDesc.\n");
336 
337  if (dstTy.isScattered())
338  return emitOpError("Expects a non-scattered TensorDesc.\n");
339 
340  if (!valTy)
341  return emitOpError("Expecting a VectorType result.\n");
342 
343  if (!isWriteHintOrNone(getL1HintAttr()))
344  return emitOpError("invalid l1_hint: ") << getL1HintAttr();
345 
346  if (!isWriteHintOrNone(getL2HintAttr()))
347  return emitOpError("invalid l2_hint: ") << getL2HintAttr();
348 
349  if (!isWriteHintOrNone(getL3HintAttr()))
350  return emitOpError("invalid l3_hint: ") << getL3HintAttr();
351 
352  auto tdescShape = getShapeOf(dstTy);
353  auto valueShape = getShapeOf(valTy);
354  auto sgMap = dstTy.getSGMapAttr();
355 
356  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
357  return emitOpError() << "Result shape doesn't match TensorDesc shape."
358  << "The expected shape is " << makeString(tdescShape)
359  << ". But the given shape is "
360  << makeString(valueShape) << ".\n";
361  return success();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // XeGPU_UpdateNDOffsetOp
366 //===----------------------------------------------------------------------===//
367 LogicalResult UpdateNdOffsetOp::verify() {
368  auto ty = getTensorDescType();
369  if (ty.isScattered())
370  return emitOpError("Expects a non-scattered TensorDesc.\n");
371 
372  // number of offsets specified must match the rank of the tensor descriptor
373  if (ty.getRank() != (int64_t)getNumOffsets()) {
374  return emitOpError("Invalid number of offsets.");
375  }
376  return success();
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // XeGPU_CreateDescOp
381 //===----------------------------------------------------------------------===//
382 
383 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
384  TensorDescType TensorDesc, Value source,
386  auto loc = source.getLoc();
387  int64_t size = static_cast<int64_t>(offsets.size());
388  auto type = VectorType::get(size, builder.getIndexType());
389  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
390  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
391  build(builder, state, TensorDesc, source, offset);
392 }
393 
394 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
395  TensorDescType TensorDesc, Value source,
396  llvm::ArrayRef<int64_t> offsets) {
397  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
398  build(builder, state, TensorDesc, source, ofrs);
399 }
400 
// Verifies the source, memory space, chunk size, per-access size limits, and
// the resulting shape of a scattered tensor descriptor.
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as ptr to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  auto chunkSize = tdescTy.getChunkSize();

  // chunk_size must be one of the hardware-supported values.
  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                                    16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  // Check per-lane access size.
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
    // For 32-bit data, the hardware can support a larger chunk size. So
    // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
    // But this requires the total size is 32 bit aligned to make the
    // optimization work.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is upto 512 bytes.");

  // The expected descriptor shape is [num_offsets] or
  // [num_offsets, chunk_size] when chunking is enabled.
  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
460 
461 //===----------------------------------------------------------------------===//
462 // XeGPU_PrefetchOp
463 //===----------------------------------------------------------------------===//
464 LogicalResult PrefetchOp::verify() {
465  auto tdescTy = getTensorDescType();
466  if (!tdescTy.isScattered())
467  return emitOpError("Expects a scattered TensorDesc.\n");
468 
469  if (!isReadHintOrNone(getL1HintAttr()))
470  return emitOpError("invalid l1_hint: ") << getL1HintAttr();
471 
472  if (!isReadHintOrNone(getL2HintAttr()))
473  return emitOpError("invalid l2_hint: ") << getL2HintAttr();
474 
475  if (!isReadHintOrNone(getL3HintAttr()))
476  return emitOpError("invalid l3_hint: ") << getL3HintAttr();
477 
478  return success();
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // XeGPU_LoadGatherOp
483 //===----------------------------------------------------------------------===//
// Verifies a gather load: scattered descriptor, read-compatible cache hints,
// element-type agreement, mask/descriptor lane agreement, and the result
// shape (after transpose and sg_map adjustments).
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Gather is a read access: every cache-level hint must be read-compatible.
  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescElemTy = tdescTy.getElementType();
  auto valueElemTy = getElementType();
  if (tdescElemTy != valueElemTy)
    return emitOpError(
        "Value should have the same element type as TensorDesc.");

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  // One mask element per dim-0 lane of the descriptor.
  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  // Rank-2 (chunked) gathers produce a transposed result; fold the transpose
  // into the expected shape before the comparison below.
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load of rank-2 tensor has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  // With an sg_map, each work-item loads a single row slice.
  if (auto sgMap = tdescTy.getSGMapAttr()) {
    auto valueVecTy = cast<VectorType>(valueTy);
    const int32_t wiData =
        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
    // All represent the same concept: a number of row elements to load.
    if (valueVecTy.getNumElements() != wiData ||
        valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
      return emitOpError("Chunk size, vector size and wi_data must match.");
    }
    // Work-item's slice (i.e., vector shape to load) is [1] or [1, chunk_size].
    tdescShape[tdescTy.getRank() - 1] = 1;
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}
540 
541 //===----------------------------------------------------------------------===//
542 // XeGPU_StoreScatterOp
543 //===----------------------------------------------------------------------===//
// Verifies a scatter store: scattered descriptor, write-compatible cache
// hints, mask/descriptor lane agreement, and the value shape (after transpose
// and sg_map adjustments). Mirrors LoadGatherOp::verify for the write path.
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Scatter is a write access: every cache-level hint must be
  // write-compatible.
  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto maskTy = getMaskType();
  auto valueTy = getValueType();
  auto maskShape = getShapeOf(maskTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);
  // One mask element per dim-0 lane of the descriptor.
  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  // Rank-2 (chunked) stores consume a transposed value; fold the transpose
  // into the expected shape before the comparison below.
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("Store of a rank-2 tensor has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  // With an sg_map, each work-item stores a single row slice.
  if (auto sgMap = tdescTy.getSGMapAttr()) {
    auto valueVecTy = cast<VectorType>(valueTy);
    const int32_t wiData =
        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
    // All represent the same concept: a number of row elements to store.
    if (valueVecTy.getNumElements() != wiData ||
        valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
      return emitOpError("Chunk size, vector size and wi_data must match.");
    }
    // Work-item's slice (i.e., vector to store) is [1] or [1, chunk_size].
    tdescShape[tdescTy.getRank() - 1] = 1;
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected value shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}
592 
593 //===----------------------------------------------------------------------===//
594 // XeGPU_UpdateOffsetOp
595 //===----------------------------------------------------------------------===//
596 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
597  mlir::Value tensorDesc,
599  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
600  assert(tdescTy && "Expecting the source is a TensorDescType value.");
601  auto loc = tensorDesc.getLoc();
602  int64_t size = static_cast<int64_t>(offsets.size());
603  auto type = VectorType::get({size}, builder.getIndexType());
604  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
605  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
606  build(builder, state, tdescTy, tensorDesc, offset);
607 }
608 
609 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
610  Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
611  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
612  build(builder, state, tensorDesc, ofrs);
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // XeGPU_DpasOp
617 //===----------------------------------------------------------------------===//
618 LogicalResult DpasOp::verify() {
619  int64_t lhsRank = getLhsType().getRank();
620  int64_t rhsRank = getRhsType().getRank();
621 
622  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
623  return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
624  "2D or 3D (packed) vector.");
625 
626  auto lhsShape = getLhsType().getShape();
627  auto rhsShape = getRhsType().getShape();
628  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
629  if (bK != lhsShape[1])
630  return emitOpError("K-dimension mismatch.");
631 
632  return success();
633 }
634 
635 } // namespace xegpu
636 } // namespace mlir
637 
638 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
639 #define GET_OP_CLASSES
640 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
This class helps build Operations.
Definition: Builders.h:205
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
static std::string makeString(T array, bool breakline=false)
Definition: XeGPUOps.cpp:30
static int64_t getRankOf(Value val)
Definition: XeGPUOps.cpp:53
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
Definition: XeGPUOps.cpp:68
static bool isReadHintOrNone(const CachePolicyAttr &attr)
Definition: XeGPUOps.cpp:60
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
static bool isArgShapesValid(ArrayRef< int64_t > descShape, ArrayRef< int64_t > valShape, SGMapAttr sgMap)
Definition: XeGPUOps.cpp:82
static SmallVector< int64_t > getShapeOf(Type type)
Definition: XeGPUOps.cpp:44
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
This represents an operation in an abstracted form, suitable for use with the builder APIs.