MLIR 23.0.0git
XeGPUOps.cpp
Go to the documentation of this file.
1//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
18
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "xegpu"
22
23using namespace mlir;
24using namespace mlir::xegpu;
25
26static bool isSharedMemory(const MemRefType &memrefTy) {
27 Attribute attr = memrefTy.getMemorySpace();
28 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
29 return intAttr.getInt() == 3;
30 if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
31 return memrefSpace.getValue() == MemorySpace::SLM;
32 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
33 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
34 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
35}
36
37template <typename T>
38static std::string makeString(T array, bool breakline = false) {
39 std::string buf;
40 buf.clear();
41 llvm::raw_string_ostream os(buf);
42 os << "[";
43 for (size_t i = 1; i < array.size(); i++) {
44 os << array[i - 1] << ", ";
45 if (breakline)
46 os << "\n\t\t";
47 }
48 os << array.back() << "]";
49 return buf;
50}
51
  // NOTE(review): the declaration line of this helper (its name, the Type
  // parameter, and the local `shape` vector) was dropped by the doxygen
  // extraction; only the body is visible here — confirm against the full
  // source.
  // Returns the shape of a ShapedType; non-shaped (scalar) types are modeled
  // as a single-element shape [1].
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}
60
61static bool isReadHintOrNone(const CachePolicyAttr &attr) {
62 if (!attr)
63 return true;
64 auto kind = attr.getValue();
65 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
66 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
67}
68
69static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
70 if (!attr)
71 return true;
72 auto kind = attr.getValue();
73 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
74 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
75}
76
// Validates mask/value/TensorDesc compatibility for gather/scatter ops that
// operate on a scattered TensorDesc.
// NOTE(review): the final signature line (the emitError diagnostic callback
// parameter) is missing from this excerpt (doxygen extraction dropped it);
// the checks below report failures through `emitError`.
static LogicalResult
isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                           TensorDescType tdescTy,

  if (!tdescTy.isScattered())
    return emitError() << "Expects a scattered TensorDesc.";

  // A missing value vector means a scalar access: chunk size must be 1 and
  // the mask must be scalar as well.
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    // A vector mask with a scalar result is inconsistent.
    if (dyn_cast<VectorType>(maskTy))
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (valueTy.getElementType() != tdescTy.getElementType())
    return emitError()
           << "Value should have the same element type as TensorDesc.";

  // The mask covers one bit per address; the trailing chunk dimension (if
  // any) is not masked per element, so drop it before comparing.
  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError()
           << "Mask should match TensorDesc except the chunk size dim.";

  // a valid shape for SIMT case
  if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
    if (tdescTy.getLayoutAttr())
      return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
    return success();
  }

  // SIMD case: the value must match the descriptor shape exactly.
  if (tdescShape != valueShape)
    return emitError() << "Value shape " << makeString(valueShape)
                       << " is neither a valid distribution for SIMT nor "
                          "consistent with the tensor descriptor for SIMD "
                       << tdescTy;
  return success();
}
123
// Validates offsets/mask/value consistency, given the chunk size, for
// buffer-based (memref or raw-pointer) gather/scatter ops.
// NOTE(review): the signature lines carrying the function name, the
// offsets/mask Type parameters, and the emitError callback parameter are
// missing from this excerpt (doxygen extraction dropped them).
static LogicalResult
                                 VectorType valueTy, int64_t chunkSize,

  auto maskVecTy = dyn_cast<VectorType>(maskTy);
  auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy);
  // Scalar (non-vector) result: chunk size must be 1 and both mask and
  // offsets must be scalars.
  if (!valueTy) {
    if (chunkSize > 1)
      return emitError() << "Expecting chunk size == 1 for scalar result";
    if (maskVecTy || offsetsVecTy)
      return emitError() << "Expecting scalar mask and offsets.";
    // NOTE(review): this branch is unreachable — whenever both operands are
    // vectors, the `||` check above has already returned. Likely a latent
    // bug; confirm the intended condition upstream.
    else if (maskVecTy && offsetsVecTy)
      return emitError() << "Expecting a vector type result.";
    return success();
  }

  auto valueSize = valueTy.getNumElements();
  // SIMT mode with scalar mask and offsets.
  if (!maskVecTy && !offsetsVecTy) {
    if (valueSize != chunkSize)
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
    return success();
  }
  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);

  if (!maskVecTy)
    return emitError() << "Expecting a vector type mask.";
  int64_t maskSize = maskVecTy.getNumElements();

  // With a chunk dimension, a 1-D value must hold exactly chunkSize
  // elements; otherwise value and mask must have the same element count.
  if (chunkSize > 1) {
    if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
      return emitError() << "value elements must match chunk size "
                         << chunkSize;
  } else {
    if (valueSize != maskSize)
      return emitError()
             << "Mask should match value except the chunk size dim.";
  }
  // Shape comparison: drop the trailing chunk dimension from the value
  // shape; a 1-element mask matches anything.
  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
  if (maskSize == 1)
    return success();
  if (chunkSize > 1)
    expectedMaskShape.pop_back();
  if (expectedMaskShape != maskShape)
    return emitError() << "Mask should match value except the chunk size dim.";

  return success();
}
175
/// Validates data vector / mem_desc / layout consistency for the matrix
/// load/store ops.
/// NOTE(review): the final signature line (the emitError diagnostic callback
/// parameter) is missing from this excerpt.
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,

  // Scalar (non-vector) data: subgroup_block_io is meaningless.
  if (!dataTy) {
    if (subgroup_block_io)
      return emitError() << "subgroup_block_io "
                            "are only allowed when result is a VectorType.";
    else
      return success();
  }

  if (mdescTy.getRank() < 2)
    return emitError() << "mem_desc must be 2D or greater.";

  ArrayRef<int64_t> dataShape = dataTy.getShape();
  ArrayRef<int64_t> mdescShape = mdescTy.getShape();

  // Materialize the mem_desc strides from their ArrayAttr form.
  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
  ArrayAttr strideAttr = mdescTy.getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }
  // With subgroup_block_io and a layout, the lane layout must describe a
  // contiguous, coalesced access matching the block shape.
  // NOTE(review): the loop below indexes laneLayout, blockShape, and strides
  // by laneData's indices — assumes they all have at least laneData.size()
  // entries; confirm that invariant is enforced elsewhere.
  if (subgroup_block_io && layout) {
    auto laneData = layout.getEffectiveLaneDataAsInt();
    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
    if (!laneData.empty()) {
      // All but the innermost laneData entries must be 1.
      bool isLaneDataContiguous =
          std::all_of(laneData.begin(), std::prev(laneData.end()),
                      [](int x) { return x == 1; });
      if (!isLaneDataContiguous)
        return emitError() << "With subgroup_block_io, accessed data must be "
                              "contiguous and coalesced.";
      for (size_t i = 0; i < laneData.size(); ++i) {
        if (laneLayout[i] != blockShape[i])
          return emitError() << "With subgroup_block_io, the block shape must "
                                "match the lane layout.";
        if (laneLayout[i] != 1 && strides[i] != 1)
          return emitError() << "With subgroup_block_io, the distributed "
                                "dimensions must be contiguous.";
      }
    }
  }
  if (dataShape.size() == 2) {
    // 2-D data must fit inside the mem_desc in every dimension.
    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
      return emitError() << "data shape must not exceed mem_desc shape.";
  } else {
    // if the subgroup_block_io attribute is set, mdescTy must have block
    // attribute
    if (subgroup_block_io && !blockShape.size())
      return emitError() << "mem_desc must have block attribute when "
                            "subgroup_block_io is set.";
    // if the subgroup_block_io attribute is set, the memdesc should be row
    // major
    if (subgroup_block_io && mdescTy.isColMajor())
      return emitError() << "mem_desc should be row major when "
                            "subgroup_block_io is set.";
  }

  return success();
}
240
241//===----------------------------------------------------------------------===//
242// XeGPU_CreateNdDescOp
243//===----------------------------------------------------------------------===//
244
245void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
246 Type tdesc, TypedValue<MemRefType> source) {
247 [[maybe_unused]] auto ty = source.getType();
248 assert(ty.hasStaticShape() && "expecting a memref with static shape");
249
250 build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
251 ValueRange({}) /* empty dynamic shape */,
252 ValueRange({}) /* empty dynamic strides */,
253 DenseI64ArrayAttr({}) /* const offsets */,
254 DenseI64ArrayAttr({}) /* empty const shape*/,
255 DenseI64ArrayAttr({}) /* empty const strides*/);
256}
257
/// Builds a CreateNdDescOp without explicit offsets from an integer-pointer
/// or memref source plus mixed (static/dynamic) shape and strides.
/// NOTE(review): the parameter lines declaring the shape/strides
/// ArrayRef<OpFoldResult> arguments are missing from this excerpt.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  // Split each OpFoldResult into either an SSA value or a static constant.
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // if shape and strides are from Memref, we don't need attributes for them
    // to keep the IR print clean (only do so for full-static case, otherwise
    // printer would fail trying to print empty array-attr).
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
        staticStridesAttr);
}
296
/// Builds a CreateNdDescOp on a statically-shaped memref with explicit mixed
/// offsets; shape and strides are inferred from the memref type.
/// NOTE(review): the parameter line declaring the offsets
/// ArrayRef<OpFoldResult> argument is missing from this excerpt.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  // Split each offset into either an SSA value or a static constant.
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
        {} /* empty const shape*/, {} /* empty const strides*/);
}
313
/// Builds a CreateNdDescOp from an integer-pointer or memref source with
/// explicit mixed offsets, shape, and strides.
/// NOTE(review): the parameter lines declaring the offsets/shape/strides
/// ArrayRef<OpFoldResult> arguments are missing from this excerpt.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, Value source,
  assert(!shape.empty() && !offsets.empty() && !strides.empty() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  Type srcTy = source.getType();
  assert((isa<IntegerType, MemRefType>(srcTy)) &&
         "Source has to be either int or memref.");

  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;

  // Split each OpFoldResult into either an SSA value or a static constant.
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
    auto memrefShape = memrefTy.getShape();
    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

    // if shape and strides are from Memref, we don't need attributes for them
    // to keep the IR print clean (only do so for full-static case, otherwise
    // printer would fail trying to print empty array-attr).
    if (staticShape == memrefShape && staticStrides == memrefStrides &&
        dynamicShape.empty() && dynamicStrides.empty()) {
      staticShapeAttr = DenseI64ArrayAttr();
      staticStridesAttr = DenseI64ArrayAttr();
    }
  }

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
359
360LogicalResult CreateNdDescOp::verify() {
361 size_t rank = getMixedSizes().size();
362 bool invalidRank = rank != getMixedStrides().size();
363 bool invalidElemTy = false;
364
365 // Memory space of created TensorDesc should match with the source.
366 // Both source and TensorDesc are considered for global memory by default,
367 // if the memory scope attr is not specified. If source is an integer,
368 // it is considered as ptr to global memory.
369 auto srcMemorySpace = getSourceMemorySpace();
370 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
371 if (srcMemorySpace != tdescMemorySpace)
372 return emitOpError("Memory space mismatch.")
373 << " Source: " << srcMemorySpace
374 << ", TensorDesc: " << tdescMemorySpace;
375
376 if (size_t offsetRank = getMixedOffsets().size())
377 invalidRank |= (offsetRank != rank);
378
379 // check source type matches the rank if it is a memref.
380 // It also should have the same ElementType as TensorDesc.
381 if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
382 invalidElemTy |= memrefTy.getElementType() != getElementType();
383
384 if (llvm::isa<IntegerType>(getSourceType())) {
385 // strides and shape must present for integer source.
386 if (getMixedStrides().empty() || getMixedSizes().empty())
387 return emitOpError("expecting strides and shape to be present for "
388 "integer source.");
389 }
390
391 if (invalidRank)
392 return emitOpError(
393 "Expecting the rank of shape, strides, offsets, and source (if source "
394 "is a memref) should match with each other.");
395
396 // check result TensorDesc rank
397 if (getType().getRank() > (int64_t)rank)
398 return emitOpError(
399 "Expecting the TensorDesc rank is not greater than the "
400 "ranks of shape, strides, offsets or the memref source.");
401
402 if (invalidElemTy)
403 return emitOpError("TensorDesc should have the same element "
404 "type with the source if it is a memref.\n");
405
406 if (getType().isScattered())
407 return emitOpError("Expects a non-scattered TensorDesc.\n");
408
409 return success();
410}
411
/// Parses an optional square-bracketed list mixing SSA values and integer
/// literals (e.g. "[%a, 2, %b]"). Dynamic entries are recorded as
/// ShapedType::kDynamic in `integers`; SSA operands are appended to `values`.
/// NOTE(review): the signature lines carrying the function name and the
/// operand/values output parameters are missing from this excerpt.
    OpAsmParser &parser,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,

  SmallVector<int64_t, 4> integerVals;
  // Each list element is either an SSA operand (recorded as kDynamic, with
  // an optional ": type" suffix when valueTypes is requested) or an integer.
  auto parseIntegerOrValue = [&]() {
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }
    return success();
  };

  // If the optional values are given there must be left bracket
  if (parser.parseOptionalLSquare().succeeded()) {
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return parser.emitError(parser.getNameLoc())
             << "expected a list of SSA values or integers";
    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
    return success();
  }

  // Absent list: succeed without touching the outputs.
  return success();
}
449
/// Prints the mixed dynamic/static index list accepted by
/// parseOptionalDynamicIndexList; the list is omitted entirely when empty.
/// NOTE(review): the first signature line (printer and op parameters) is
/// missing from this excerpt.
                                          OperandRange values,
                                          DenseI64ArrayAttr integers) {
  if (!integers || integers.empty())
    return;
  printDynamicIndexList(printer, op, values, integers,
                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
458//===----------------------------------------------------------------------===//
459// XeGPU_PrefetchNdOp
460//===----------------------------------------------------------------------===//
461
462void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
463 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
464 xegpu::CachePolicyAttr l2_hint,
465 xegpu::CachePolicyAttr l3_hint) {
466
467 return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
468 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
469}
470
471void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
472 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
473 xegpu::CachePolicyAttr l1_hint,
474 xegpu::CachePolicyAttr l2_hint,
475 xegpu::CachePolicyAttr l3_hint,
476 xegpu::DistributeLayoutAttr layout) {
477 SmallVector<Value> dynamicOffsets;
478 SmallVector<int64_t> staticOffsets;
479 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
480
481 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
482
483 build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
484 l2_hint, l3_hint, /*anchor_layout=*/layout);
485}
486
487LogicalResult PrefetchNdOp::verify() {
488 auto tdescTy = getTensorDescType();
489 if (tdescTy.isScattered())
490 return emitOpError("Expects a non-scattered TensorDesc.\n");
491
492 if (!isReadHintOrNone(getL1HintAttr()))
493 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
494
495 if (!isReadHintOrNone(getL2HintAttr()))
496 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
497
498 if (!isReadHintOrNone(getL3HintAttr()))
499 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
500
501 int64_t tDescRank = tdescTy.getRank();
502 int64_t offsetSize = getMixedOffsets().size();
503 if (offsetSize != 0 && offsetSize != tDescRank)
504 return emitOpError(
505 "Mismatched ranks between offsets and tensor descriptor");
506
507 return success();
508}
509
510//===----------------------------------------------------------------------===//
511// XeGPU_LoadNdOp
512//===----------------------------------------------------------------------===//
513
514void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
515 Value tensorDesc, UnitAttr packed,
516 DenseI64ArrayAttr transpose,
517 xegpu::CachePolicyAttr l1_hint,
518 xegpu::CachePolicyAttr l2_hint,
519 xegpu::CachePolicyAttr l3_hint) {
520
521 return build(builder, state, retType, tensorDesc, ValueRange(),
522 DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
523 l3_hint, /*anchor_layout=*/nullptr);
524}
525
526void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
527 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
528 UnitAttr packed, DenseI64ArrayAttr transpose,
529 xegpu::CachePolicyAttr l1_hint,
530 xegpu::CachePolicyAttr l2_hint,
531 xegpu::CachePolicyAttr l3_hint,
532 xegpu::DistributeLayoutAttr layout) {
533 SmallVector<Value> dynamicOffsets;
534 SmallVector<int64_t> staticOffsets;
535 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
536
537 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
538
539 build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
540 packed, transpose, l1_hint, l2_hint, l3_hint,
541 /*anchor_layout=*/layout);
542}
543
/// Verifier for LoadNdOp. Two modes are accepted:
///  - SIMT: a 1-D result with fewer elements than the descriptor; the load is
///    distributed across subgroup lanes, so only even divisibility is checked.
///  - SIMD: the result shape must equal the descriptor shape after the
///    optional transpose/packed transforms and array_length expansion.
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (tdescTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  // Total element counts, with array_length folded into the descriptor side.
  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
  int valueElems = valueTy.getNumElements();

  // If the result vector is 1D and has less elements than the tensor
  // descriptor, it is supposed to be a SIMT op. The layout attribute in
  // tensor_desc is not needed.
  if (valueElems < tdescElems && valueTy.getRank() == 1) {
    // SIMT mode doesn't need LayoutAttr.
    if (tdescTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // For SIMT code, the load is evenly distributed across all lanes in a
    // subgroup. Since subgroup size is arch dependent, we only check even
    // distribution here.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Result shape " << makeString(getShapeOf(valueTy))
             << " is not a valid distribution for tensor descriptor "
             << tdescTy;

    return success();
  }

  // Check SIMD mode.
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  // An explicit transpose permutes the expected descriptor shape before the
  // comparison; an out-of-range permutation is ignored with a warning.
  if (getTranspose()) {
    auto trans = getTranspose().value();
    // Make sure the transpose value is valid, and apply it
    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
      tdescShape = applyPermutation(tdescShape, trans);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  // The packed (VNNI) form folds a factor of dim 0 into a new innermost
  // dimension; only meaningful for 2-D descriptors.
  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  // array_length > 1 prepends a leading dimension to the expected shape.
  auto array_len = tdescTy.getArrayLength();
  if (array_len > 1)
    tdescShape.insert(tdescShape.begin(), array_len);

  if (tdescShape != valueShape)
    return emitOpError() << "Result shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << tdescTy;

  // Offsets are optional; when present they must match the descriptor rank.
  int64_t tDescRank = tdescTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
633
634//===----------------------------------------------------------------------===//
635// XeGPU_StoreNdOp
636//===----------------------------------------------------------------------===//
637
638void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
639 Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
640 xegpu::CachePolicyAttr l2_hint,
641 xegpu::CachePolicyAttr l3_hint) {
642
643 return build(builder, state, value, tensorDesc, ValueRange(),
644 DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint,
645 /*anchor_layout=*/nullptr);
646}
647
648void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
649 Value tensorDesc, ArrayRef<OpFoldResult> offsets,
650 xegpu::CachePolicyAttr l1_hint,
651 xegpu::CachePolicyAttr l2_hint,
652 xegpu::CachePolicyAttr l3_hint,
653 xegpu::DistributeLayoutAttr layout) {
654 SmallVector<Value> dynamicOffsets;
655 SmallVector<int64_t> staticOffsets;
656 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
657
658 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
659
660 build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
661 l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
662}
663
/// Verifier for StoreNdOp. Mirrors LoadNdOp::verify: a 1-D value with fewer
/// elements than the descriptor is treated as SIMT (checked for even lane
/// distribution); otherwise the value shape must equal the descriptor shape.
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (dstTy.getRank() > 2)
    return emitOpError("Expects a 1D or 2D TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  // Stores require write-compatible cache hints.
  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = dstTy.getArrayLength();
  if (array_len > 1)
    return emitOpError("array length is not supported by store_nd.\n");

  auto tdescElems = dstTy.getNumElements();
  auto valueElems = valTy.getNumElements();

  // Similar to LoadNdOp, if the value vector is 1D and has less elements than
  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
  // in tensor_desc is not needed.
  if (valTy.getRank() == 1 && valueElems < tdescElems) {
    // SIMT mode doesn't need LayoutAttr.
    if (dstTy.getLayoutAttr())
      return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

    // Only even distribution is checked; subgroup size is arch dependent.
    if (tdescElems % valueElems)
      return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;

    return success();
  }

  // SIMD code should have the same shape as the tensor descriptor.
  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  if (tdescShape != valueShape)
    return emitOpError() << "Value shape " << makeString(valueShape)
                         << " is not consistent with tensor descriptor "
                         << dstTy;

  // Offsets are optional; when present they must match the descriptor rank.
  int64_t tDescRank = dstTy.getRank();
  int64_t offsetSize = getMixedOffsets().size();
  if (offsetSize != 0 && offsetSize != tDescRank)
    return emitOpError(
        "Mismatched ranks between offsets and tensor descriptor");

  return success();
}
726
727//===----------------------------------------------------------------------===//
728// XeGPU_UpdateNDOffsetOp
729//===----------------------------------------------------------------------===//
730LogicalResult UpdateNdOffsetOp::verify() {
731 auto ty = getTensorDescType();
732 if (ty.isScattered())
733 return emitOpError("Expects a non-scattered TensorDesc.\n");
734
735 // number of offsets specified must match the rank of the tensor descriptor
736 if (ty.getRank() != (int64_t)getNumOffsets()) {
737 return emitOpError("Invalid number of offsets.");
738 }
739 return success();
740}
741
742//===----------------------------------------------------------------------===//
743// XeGPU_CreateDescOp
744//===----------------------------------------------------------------------===//
745
/// Builds a CreateDescOp from mixed (static/dynamic) offsets by materializing
/// them as a vector<index> via vector.from_elements.
/// NOTE(review): the final parameter line (the offsets ArrayRef) is missing
/// from this excerpt.
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  // Materialize every offset as an index SSA value, then pack them.
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}
756
757void CreateDescOp::build(OpBuilder &builder, OperationState &state,
758 TensorDescType TensorDesc, Value source,
759 llvm::ArrayRef<int64_t> offsets) {
760 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
761 build(builder, state, TensorDesc, source, ofrs);
762}
763
/// Verifier for CreateDescOp: the result must be a scattered TensorDesc whose
/// memory space matches the source and whose shape equals the offsets shape,
/// extended by the chunk size when chunk size != 1.
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as ptr to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // check total size
  // Expected shape = offsets shape (+ trailing chunkSize when != 1).
  auto chunkSize = tdescTy.getChunkSizeAsInt();
  SmallVector<int64_t> shape(getOffsetsType().getShape());
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
794
795//===----------------------------------------------------------------------===//
796// XeGPU_PrefetchOp
797//===----------------------------------------------------------------------===//
798LogicalResult PrefetchOp::verify() {
799 auto tdescTy = getTensorDescType();
800
801 if (!tdescTy && !getOffsets())
802 return emitOpError("Expects offsets.");
803
804 if (tdescTy && getOffsets())
805 return emitOpError("offsets not allowed.");
806
807 if (tdescTy && !tdescTy.isScattered())
808 return emitOpError("Expects a scattered TensorDesc.");
809
810 if (!isReadHintOrNone(getL1HintAttr()))
811 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
812
813 if (!isReadHintOrNone(getL2HintAttr()))
814 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
815
816 if (!isReadHintOrNone(getL3HintAttr()))
817 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
818
819 auto srcTy = getSourceType();
820 if (srcTy.isInteger() && !getOffsetAlignByteAttr())
821 return emitOpError("offset_align_byte is required with integer source.");
822
823 if (getOffsetAlignByteAttr() && !srcTy.isInteger())
824 return emitOpError("offset_align_byte only allowed with integer source.");
825
826 return success();
827}
828
829void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
830 xegpu::CachePolicyAttr l1_hint,
831 xegpu::CachePolicyAttr l2_hint,
832 xegpu::CachePolicyAttr l3_hint) {
833 build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint,
834 IntegerAttr{}, /*anchor_layout=*/nullptr);
835}
836
837//===----------------------------------------------------------------------===//
838// XeGPU_LoadGatherOp
839//===----------------------------------------------------------------------===//
840LogicalResult LoadGatherOp::verify() {
841 auto tdescTy = getTensorDescType();
842 auto maskTy = getMaskType();
843 auto valueTy = getValueType();
844
845 if (!tdescTy && !getOffsets())
846 return emitOpError("Expects offsets.");
847
848 if (tdescTy && getOffsets())
849 return emitOpError("offsets not allowed.");
850
851 if (tdescTy && !tdescTy.isScattered())
852 return emitOpError("Expects a scattered TensorDesc.");
853
854 if (!isReadHintOrNone(getL1HintAttr()))
855 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
856
857 if (!isReadHintOrNone(getL2HintAttr()))
858 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
859
860 if (!isReadHintOrNone(getL3HintAttr()))
861 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
862
863 if (tdescTy)
864 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
865 [&]() { return emitOpError(); });
866 auto srcTy = getSourceType();
867 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
868 auto memTy = dyn_cast<MemRefType>(srcTy);
869
870 if (memTy && (getElementType() != memTy.getElementType()))
871 return emitError() << "Value should have the same element type as MemRef.";
872
873 auto offsetsTy = getOffsets().getType();
874 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
875 [&]() { return emitOpError(); });
876}
877
// Convenience builder for the TensorDesc-based form (no explicit offsets):
// the offset operand and chunk_size are left unset, and no anchor layout is
// attached.
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
886
887void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
888 Type valueType, Value source,
889 ArrayRef<OpFoldResult> offsets, Value mask,
890 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
891 xegpu::CachePolicyAttr l2_hint,
892 xegpu::CachePolicyAttr l3_hint) {
893 auto loc = source.getLoc();
894 int64_t size = static_cast<int64_t>(offsets.size());
895 auto type = VectorType::get(size, builder.getIndexType());
896 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
897 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
898
899 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
900 l2_hint, l3_hint, /*anchor_layout=*/nullptr);
901}
902
903void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
904 Type valueType, Value source,
905 ArrayRef<OpFoldResult> offsets, Value mask,
906 IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
907 xegpu::CachePolicyAttr l2_hint,
908 xegpu::CachePolicyAttr l3_hint,
909 DistributeLayoutAttr layout) {
910 auto loc = source.getLoc();
911 int64_t size = static_cast<int64_t>(offsets.size());
912 auto type = VectorType::get(size, builder.getIndexType());
913 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
914 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
915
916 build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
917 l2_hint, l3_hint, layout);
918}
919
920//===----------------------------------------------------------------------===//
921// XeGPU_StoreScatterOp
922//===----------------------------------------------------------------------===//
923LogicalResult StoreScatterOp::verify() {
924 auto tdescTy = getTensorDescType();
925 auto maskTy = getMaskType();
926 auto valueTy = getValueType();
927
928 if (!tdescTy && !getOffsets())
929 return emitOpError("Expects offsets.");
930
931 if (tdescTy && getOffsets())
932 return emitOpError("offsets not allowed.");
933
934 if (tdescTy && !tdescTy.isScattered())
935 return emitOpError("Expects a scattered TensorDesc.");
936
937 if (!isWriteHintOrNone(getL1HintAttr()))
938 return emitOpError("invalid l1_hint: ") << getL1HintAttr();
939
940 if (!isWriteHintOrNone(getL2HintAttr()))
941 return emitOpError("invalid l2_hint: ") << getL2HintAttr();
942
943 if (!isWriteHintOrNone(getL3HintAttr()))
944 return emitOpError("invalid l3_hint: ") << getL3HintAttr();
945
946 if (tdescTy)
947 return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
948 [&]() { return emitOpError(); });
949
950 auto destTy = getDestType();
951 uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
952 auto memTy = dyn_cast<MemRefType>(destTy);
953
954 if (memTy && (getElementType() != memTy.getElementType()))
955 return emitError() << "Value should have the same element type as MemRef.";
956
957 auto offsetsTy = getOffsets().getType();
958 return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
959 [&]() { return emitOpError(); });
960}
961
// Convenience builder for the TensorDesc-based form (no explicit offsets):
// the offset operand and chunk_size are left unset, and no anchor layout is
// attached.
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
        l2_hint, l3_hint, /*anchor_layout=*/nullptr);
}
970
971void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
972 Value value, Value dest,
973 ArrayRef<OpFoldResult> offsets, Value mask,
974 IntegerAttr chunk_size,
975 xegpu::CachePolicyAttr l1_hint,
976 xegpu::CachePolicyAttr l2_hint,
977 xegpu::CachePolicyAttr l3_hint) {
978 auto loc = dest.getLoc();
979 int64_t size = static_cast<int64_t>(offsets.size());
980 auto type = VectorType::get(size, builder.getIndexType());
981 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
982 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
983
984 // Call the correct builder overload that does not expect result types.
985 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
986 l3_hint, /*anchor_layout=*/nullptr);
987}
988
989void StoreScatterOp::build(
990 OpBuilder &builder, OperationState &state, Value value, Value dest,
991 ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
992 xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
993 xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
994 auto loc = dest.getLoc();
995 int64_t size = static_cast<int64_t>(offsets.size());
996 auto type = VectorType::get(size, builder.getIndexType());
997 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
998 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
999
1000 // Call the correct builder overload that does not expect result types.
1001 build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
1002 l3_hint, layout);
1003}
1004
1005//===----------------------------------------------------------------------===//
1006// XeGPU_UpdateOffsetOp
1007//===----------------------------------------------------------------------===//
1008void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1009 mlir::Value tensorDesc,
1011 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
1012 assert(tdescTy && "Expecting the source is a TensorDescType value.");
1013 auto loc = tensorDesc.getLoc();
1014 int64_t size = static_cast<int64_t>(offsets.size());
1015 auto type = VectorType::get({size}, builder.getIndexType());
1016 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
1017 auto offset = vector::FromElementsOp::create(builder, loc, type, values);
1018 build(builder, state, tdescTy, tensorDesc, offset);
1019}
1020
1021void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
1022 Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
1023 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
1024 build(builder, state, tensorDesc, ofrs);
1025}
1026
1027LogicalResult UpdateOffsetOp::verify() {
1028 auto tdescTy = getTensorDescType();
1029 if (!tdescTy.isScattered())
1030 return emitOpError("Expects a scattered TensorDesc.\n");
1031
1032 SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
1033 SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
1034 if (tdescTy.getChunkSizeAsInt() > 1)
1035 expectedOffsetShape.pop_back();
1036
1037 if (expectedOffsetShape != offsetShape)
1038 return emitOpError(
1039 "Offsets should match TensorDesc except the chunk size dim.");
1040
1041 return success();
1042}
1043
1044//===----------------------------------------------------------------------===//
1045// XeGPU_DpasOp
1046//===----------------------------------------------------------------------===//
1047LogicalResult DpasOp::verify() {
1048 int64_t lhsRank = getLhsType().getRank();
1049 int64_t rhsRank = getRhsType().getRank();
1050 int64_t resRank = getResultType().getRank();
1051 auto lhsShape = getLhsType().getShape();
1052 auto rhsShape = getRhsType().getShape();
1053 auto resShape = getResultType().getShape();
1054
1055 if (getAcc() && getAcc().getType() != getResultType())
1056 return emitOpError("Expecting the acc type to be the same as result.");
1057
1058 // SIMT code: the size of the B operand has to be a multiple of 32 bits.
1059 // It skips the semantic check since lack of architecture information.
1060 // Users need to ensure the correctness.
1061 if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
1062 auto numElems = getRhsType().getNumElements();
1063 auto elemTy = getRhsType().getElementType();
1064 auto factor = 32 / elemTy.getIntOrFloatBitWidth();
1065 if (numElems % factor != 0)
1066 return emitOpError("Expecting B operand to be a multiple of 32 bits.");
1067 return success();
1068 }
1069
1070 // SIMD code
1071 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
1072 return emitOpError(
1073 "expecting lhs and result to be a 2D vector, and rhs to be either "
1074 "2D or 3D (packed) vector.");
1075 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
1076 if (bK != lhsShape[1])
1077 return emitOpError("K-dimension mismatch.");
1078 if (lhsShape[0] != resShape[0])
1079 return emitOpError("M-dimension mismatch.");
1080 if (rhsShape[1] != resShape[1])
1081 return emitOpError("N-dimension mismatch.");
1082
1083 return success();
1084}
1085
1086//===----------------------------------------------------------------------===//
1087// XeGPU_ConvertLayoutOp
1088//===----------------------------------------------------------------------===//
1089LogicalResult ConvertLayoutOp::verify() {
1090 auto srcLayout = getInputLayout();
1091 auto resLayout = getTargetLayout();
1092 if (!srcLayout)
1093 return emitOpError("expected input layout.");
1094 if (!resLayout)
1095 return emitOpError("expected target layout.");
1096
1097 // both input and target layouts should be WgLayout or SgLayout at the same
1098 // time.
1099 if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
1100 (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
1101 return emitOpError("expected input layout and target layout be WgLayout or "
1102 "SgLayout at the same time.");
1103
1104 auto shape = getSource().getType().getShape();
1105 if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
1106 return emitOpError(
1107 "invalid input layout, data cannot be evenly distributed.");
1108
1109 if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
1110 return emitOpError(
1111 "invalid target layout, data cannot be evenly distributed.");
1112
1113 return mlir::success();
1114}
1115
1116//===----------------------------------------------------------------------===//
1117// XeGPU_LoadMatrixOp
1118//===----------------------------------------------------------------------===//
1119void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
1122 DistributeLayoutAttr layout) {
1123 llvm::SmallVector<Value> dynamicOffsets;
1124 llvm::SmallVector<int64_t> staticOffsets;
1125 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1126 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1127 // Call the generated builder with all parameters (including optional ones as
1128 // nullptr/empty)
1129 build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1130 /*subgroup_block_io=*/nullptr, layout);
1131}
1132
1133LogicalResult LoadMatrixOp::verify() {
1134
1135 auto resTy = dyn_cast<VectorType>(getRes().getType());
1136 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1137 MemDescType mdescTy = getMemDesc().getType();
1138
1139 return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1140 getLayoutAttr(), [&]() { return emitError(); });
1141}
1142
1143//===----------------------------------------------------------------------===//
1144// XeGPU_StoreMatrixOp
1145//===----------------------------------------------------------------------===//
1146void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
1149 DistributeLayoutAttr layout) {
1150 llvm::SmallVector<Value> dynamicOffsets;
1151 llvm::SmallVector<int64_t> staticOffsets;
1152 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1153 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1154 build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1155 /*subgroup_block_io=*/nullptr, layout);
1156}
1157
1158LogicalResult StoreMatrixOp::verify() {
1159
1160 auto dataTy = dyn_cast<VectorType>(getData().getType());
1161 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1162 MemDescType mdescTy = getMemDesc().getType();
1163 return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1164 getLayoutAttr(), [&]() { return emitError(); });
1165}
1166
1167//===----------------------------------------------------------------------===//
1168// XeGPU_TruncfOp
1169//===----------------------------------------------------------------------===//
1170
1171LogicalResult TruncfOp::verify() {
1172 auto sourceVecType = dyn_cast<VectorType>(getSource().getType());
1173 auto resultVecType = dyn_cast<VectorType>(getResult().getType());
1174
1175 if (sourceVecType.getElementTypeBitWidth() <=
1176 resultVecType.getElementTypeBitWidth())
1177 return emitOpError("input type must be wider than result type.");
1178
1179 return success();
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// XeGPU_DpasMxOp
1184//===----------------------------------------------------------------------===//
1185
1186LogicalResult DpasMxOp::verify() {
1187 if (getAcc() && getAcc().getType() != getResultType())
1188 return emitOpError("Expecting the acc type to be the same as result.");
1189
1190 return success();
1191}
1192
1193namespace mlir {
1194#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
1195} // namespace mlir
1196#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
1197#define GET_OP_CLASSES
1198#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
ArrayAttr()
static Type getValueType(Attribute attr)
Definition SPIRVOps.cpp:773
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< int64_t > getShapeOf(Type type)
Definition XeGPUOps.cpp:52
LogicalResult IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io, DistributeLayoutAttr layout, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:177
static std::string makeString(T array, bool breakline=false)
Definition XeGPUOps.cpp:38
static bool isWriteHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:69
static bool isReadHintOrNone(const CachePolicyAttr &attr)
Definition XeGPUOps.cpp:61
static LogicalResult isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, VectorType valueTy, int64_t chunkSize, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:125
static void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers)
Definition XeGPUOps.cpp:450
static bool isSharedMemory(const MemRefType &memrefTy)
Definition XeGPUOps.cpp:26
static ParseResult parseOptionalDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Definition XeGPUOps.cpp:412
static LogicalResult isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy, function_ref< InFlightDiagnostic()> emitError)
Definition XeGPUOps.cpp:78
Delimiter
These are the supported delimiters around operand lists and region argument lists,...
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Attributes are known-constant values of operations.
Definition Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This class represents a diagnostic that is inflight and set to be reported.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition Builders.h:209
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.