//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

bool isSharedMemory(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == 3;
  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
    return memrefSpace.getValue() == MemorySpace::SLM;
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
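// For example, a memref whose memory space is the integer attribute 3, an
// xegpu MemorySpace::SLM attribute, an xevm SHARED address space, or the GPU
// dialect workgroup space are all classified as shared local memory here.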

template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  buf.clear();
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}
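// For example, makeString(ArrayRef<int64_t>{8, 16, 32}) yields "[8, 16, 32]";
// with breakline = true, each separator is additionally followed by a newline
// and tab indentation.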

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,
                           function_ref<InFlightDiagnostic()> emitError) {

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  auto chunkSize = tdescTy.getChunkSizeAsInt();
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // a valid shape for SIMT case
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
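// Worked example of the rules above: for a scattered TensorDesc of shape
// [16, 8] with chunk size 8, the mask must have shape [16]; a rank-1 value
// with exactly 8 elements is accepted as a SIMT distribution, while a SIMD
// value must match the full [16, 8] shape and element type.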

static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                 VectorType valueTy, int64_t chunkSize,
                                 function_ref<InFlightDiagnostic()> emitError) {

  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    else if (maskVecTy && offsetsVecTy)
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto valueSize = valueTy.getNumElements();
  // SIMT mode with scalar mask and offsets.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }
  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);

  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  if (chunkSize > 1) {
    if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
  } else {
    if (valueSize != maskSize)
      return emitError()
             << "Mask should match value except the chunk size dim.";
  }
  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (maskSize == 1)
    return success();
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

  return success();
}
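// Worked example: a value of shape [16, 8] with chunk size 8 and vector mask
// and offsets requires a mask of shape [16] (the trailing chunk dimension is
// dropped); with scalar mask and offsets, a value with exactly 8 elements
// (the chunk size) is expected instead.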

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        DenseI64ArrayAttr({}) /* const offsets */,
        DenseI64ArrayAttr({}) /* empty const shape*/,
        DenseI64ArrayAttr({}) /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // if shape and strides are from Memref, we don't need attributes for them
    // to keep the IR print clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
        staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
        {} /* empty const shape*/, {} /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // if shape and strides are from Memref, we don't need attributes for them
    // to keep the IR print clean.
    if (staticShape == memrefShape && staticStrides == memrefStrides) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  size_t rank = getMixedSizes().size();
  bool invalidRank = rank != getMixedStrides().size();
  bool invalidElemTy = false;

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as a ptr to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  if (size_t offsetRank = getMixedOffsets().size())
    invalidRank |= (offsetRank != rank);

  // check source type matches the rank if it is a memref.
  // It also should have the same ElementType as TensorDesc.
  if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
    invalidElemTy |= memrefTy.getElementType() != getElementType();

  if (llvm::isa<IntegerType>(getSourceType())) {
    // strides and shape must be present for an integer source.
    if (getMixedStrides().empty() || getMixedSizes().empty())
      return emitOpError("expecting strides and shape to be present for "
                         "integer source.");
  }

  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // check result TensorDesc rank
  if (getType().getRank() > (int64_t)rank)
    return emitOpError(
        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}

ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {

  SmallVector<int64_t, 4> integerVals;
  auto parseIntegerOrValue = [&]() {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // If the optional values are given, there must be a left bracket.
  if (parser.parseOptionalLSquare().succeeded()) {
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return parser.emitError(parser.getNameLoc())
             << "expected a list of SSA values or integers";
    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
    return success();
  }

  return success();
}
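// For example, this custom directive accepts "[%a, 16, %b]": the SSA operands
// are collected into `values` and `integers` becomes [kDynamic, 16, kDynamic].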

void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                   OperandRange values,
                                   DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers,
                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {

  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
               l1_hint, l2_hint, l3_hint);
}

void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
        l2_hint, l3_hint);
}

LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
  int64_t constOffsetSize =
      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, UnitAttr packed,
                     DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {

  return build(builder, state, retType, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
               l3_hint);
}

void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                     UnitAttr packed, DenseI64ArrayAttr transpose,
                     xegpu::CachePolicyAttr l1_hint,
                     xegpu::CachePolicyAttr l2_hint,
                     xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        packed, transpose, l1_hint, l2_hint, l3_hint);
}

LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has fewer elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since subgroup size is arch dependent, we only check even
    // distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Make sure the transpose value is valid, and apply it.
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
  int64_t constOffsetSize =
      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
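// SIMD shape-check example: a 16x16 f16 TensorDesc loaded with the packed
// flag into vector<8x16x2xf16> passes, since the vnni factor is 2 and the
// expected shape becomes [16/2, 16, 2] = [8, 16, 2].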

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {

  return build(builder, state, value, tensorDesc, ValueRange(),
               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}

void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
                      xegpu::CachePolicyAttr l3_hint) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);

  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
        l1_hint, l2_hint, l3_hint);
}

LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp, if the value vector is 1D and has fewer elements than
  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
  // in tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
  int64_t constOffsetSize =
      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
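// SIMT distribution example: storing vector<8xf16> through an 8x16 f16
// TensorDesc (128 elements) is accepted because 128 % 8 == 0, whereas
// vector<5xf16> would be rejected.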

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as a ptr to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // check total size
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  SmallVector<int64_t> shape(getOffsetsType().getShape());
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
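// For example, with offsets of type vector<16xindex> and chunk size 2, the
// TensorDesc shape must be [16, 2]; with chunk size 1 it must be [16].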

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto srcTy = getSourceType();
  if (srcTy.isInteger() && !getOffsetAlignByteAttr())
    return emitOpError("offset_align_byte is required with integer source.");

  if (getOffsetAlignByteAttr() && !srcTy.isInteger())
    return emitOpError("offset_align_byte only allowed with integer source.");

  return success();
}

void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
        IntegerAttr{});
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });
  auto srcTy = getSourceType();
  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(srcTy);

  if (memTy && (getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                          [&]() { return emitOpError(); });
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint);
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source,
                         ArrayRef<OpFoldResult> offsets, Value mask,
                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
        l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy && !getOffsets())
    return emitOpError("Expects offsets.");

  if (tdescTy && getOffsets())
    return emitOpError("offsets not allowed.");

  if (tdescTy && !tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  if (tdescTy)
    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
                                      [&]() { return emitOpError(); });

  auto destTy = getDestType();
  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
  auto memTy = dyn_cast<MemRefType>(destTy);

  if (memTy && (getElementType() != memTy.getElementType()))
    return emitError() << "Value should have the same element type as MemRef.";

  auto offsetsTy = getOffsets().getType();
  return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                          [&]() { return emitOpError(); });
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint);
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest,
                           ArrayRef<OpFoldResult> offsets, Value mask,
                           IntegerAttr chunk_size,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  auto loc = dest.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);

  // Call the correct builder overload that does not expect result types.
  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
        l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

LogicalResult UpdateOffsetOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
  SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
  if (tdescTy.getChunkSizeAsInt() > 1)
    expectedOffsetShape.pop_back();

  if (expectedOffsetShape != offsetShape)
    return emitOpError(
        "Offsets should match TensorDesc except the chunk size dim.");

  return success();
}
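// For example, a scattered TensorDesc of shape [16, 2] with chunk size 2
// expects offsets of shape [16]; with chunk size 1, offsets must match the
// TensorDesc shape exactly.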

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();
  int64_t resRank = getResultType().getRank();
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto resShape = getResultType().getShape();

  if (getAcc() && getAcc().getType() != getResultType())
    return emitOpError("Expecting the acc type to be the same as result.");

  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
  // The full semantic check is skipped since architecture information is not
  // available here; users need to ensure correctness.
  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
    auto numElems = getRhsType().getNumElements();
    auto elemTy = getRhsType().getElementType();
    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
    if (numElems % factor != 0)
      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
    return success();
  }

  // SIMD code
  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
    return emitOpError(
        "expecting lhs and result to be a 2D vector, and rhs to be either "
        "2D or 3D (packed) vector.");
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
  if (lhsShape[0] != resShape[0])
    return emitOpError("M-dimension mismatch.");
  if (rhsShape[1] != resShape[1])
    return emitOpError("N-dimension mismatch.");

  return success();
}
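// SIMD shape example: lhs vector<8x16xf16>, packed rhs vector<8x16x2xf16>,
// and result vector<8x16xf32> pass these checks, since K = 8 * 2 = 16 matches
// lhsShape[1], M = 8 matches resShape[0], and N = 16 matches resShape[1].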

//===----------------------------------------------------------------------===//
// XeGPU_ConvertLayoutOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertLayoutOp::verify() {
  auto srcLayout = getInputLayout();
  auto resLayout = getTargetLayout();
  if (!srcLayout)
    return emitOpError("expected input layout.");
  if (!resLayout)
    return emitOpError("expected target layout.");

  // Both input and target layouts should be WgLayout or SgLayout at the same
  // time.
  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
    return emitOpError("expected input layout and target layout be WgLayout or "
                       "SgLayout at the same time.");

  auto shape = getSource().getType().getShape();
  if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
    return emitOpError(
        "invalid input layout, data cannot be evenly distributed.");

  if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
    return emitOpError(
        "invalid target layout, data cannot be evenly distributed.");

  return mlir::success();
}

OpFoldResult ConvertLayoutOp::fold(FoldAdaptor adaptor) {
  if (getInputLayout() == getTargetLayout())
    return getSource();
  return {};
}

struct FoldConvertLayoutOp : public OpRewritePattern<xegpu::ConvertLayoutOp> {
  using OpRewritePattern<xegpu::ConvertLayoutOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getInputLayout() == op.getTargetLayout()) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }
    return failure();
  }
};

void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<FoldConvertLayoutOp>(context);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadMatrixOp
//===----------------------------------------------------------------------===//
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MemDescType> memDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
                         DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
        layout);
}

LogicalResult LoadMatrixOp::verify() {
  VectorType resTy = getRes().getType();
  MemDescType mdescTy = getMemDesc().getType();

  if (mdescTy.getRank() != 2)
    return emitOpError("mem_desc must be 2D.");

  ArrayRef<int64_t> valueShape = resTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("result shape must not exceed mem_desc shape.");
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreMatrixOp
//===----------------------------------------------------------------------===//
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
                          DistributeLayoutAttr layout) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
        layout);
}

LogicalResult StoreMatrixOp::verify() {
  VectorType dataTy = getData().getType();
  MemDescType mdescTy = getMemDesc().getType();

  if (mdescTy.getRank() != 2)
    return emitOpError("mem_desc must be 2D.");

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
  if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("data shape must not exceed mem_desc shape.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_MemDescSubviewOp
//===----------------------------------------------------------------------===//

void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
                             Type resTy, Value src,
                             llvm::ArrayRef<OpFoldResult> offsets) {
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
}

LogicalResult MemDescSubviewOp::verify() {
  MemDescType srcTy = getSrc().getType();
  MemDescType resTy = getRes().getType();
  ArrayRef<int64_t> srcShape = srcTy.getShape();
  ArrayRef<int64_t> resShape = resTy.getShape();

  if (srcTy.getRank() < resTy.getRank())
    return emitOpError("result rank must not exceed source rank.");

  if (llvm::any_of(
          llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
          [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
    return emitOpError("result shape must not exceed source shape.");

  if (srcTy.getStrides() != resTy.getStrides())
    return emitOpError("result must inherit the source strides.");

  return success();
}

} // namespace xegpu
} // namespace mlir

namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>