#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"
// Permute `shape` in place according to the permutation `trans`.
static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}
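// Hedged usage sketch (illustrative, not part of the original source): given
// shape = {8, 16}, transpose({1, 0}, shape) rewrites it in place to {16, 8}.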
// Render an array as a string such as "[2, 8, 16]"; assumes a non-empty array.
template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}
// In getShapeOf(Type type), the shape is read from the type when it is shaped;
// otherwise the value is treated as a single-element shape:
if (auto ty = llvm::dyn_cast<ShapedType>(type))
  shape = SmallVector<int64_t>(ty.getShape());

// In getRankOf(Value val), the rank is likewise queried from a ShapedType;
// non-shaped values are treated as rank 0:
if (auto ty = llvm::dyn_cast<ShapedType>(type))
  return ty.getRank();
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}
static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
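// Hedged example (not part of the original source): an l1_hint of
// CachePolicy::CACHED is accepted by both predicates, CachePolicy::STREAMING
// only by isReadHintOrNone, CachePolicy::WRITE_BACK only by isWriteHintOrNone,
// and an absent hint attribute is accepted by both.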
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
  // Identical shapes with no subgroup map are trivially valid.
  if (descShape == valShape && !sgMap)
    return true;
  // ...
  size_t descRank = descShape.size();
  if (descRank > 2 || valShape.size() != descRank)
    return false;
  // ...
  // For a 1D descriptor only the innermost dimension of the WI layout applies.
  mapLayout = {wiLayout.back()};

  for (const auto &[factor, dim, expected] :
       llvm::zip_equal(mapLayout, valShape, descShape)) {
    if (factor * dim != expected)
      return false;
  }
  return true;
}
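// Hedged example (not part of the original source): for a subgroup map with
// wi_layout = [1, 16], a TensorDesc shape of {8, 16} distributed across 16
// work items gives a per-item value shape of {8, 1}; each factor * dim
// product (1 * 8 and 16 * 1) matches the corresponding descriptor dimension.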
// In CreateNdDescOp::build (memref source, offsets only): the source must have
// a static shape, with one offset per source dimension; shape and strides are
// taken from the memref type itself.
[[maybe_unused]] auto ty = source.getType();
assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
// ... (offsets are split into static and dynamic parts)
build(builder, state, tdesc, source, dynamicOffsets,
      ...); // remaining shape, stride, and static-offset operands elided
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());
  // ... (the mixed offsets/shape/strides are split into static and dynamic
  // parts via dispatchIndexOpFoldResults; elided in this excerpt)
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());
  // ... (same static/dynamic dispatch as the memref overload; elided)
  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
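// Hedged usage sketch (illustrative only; the identifiers `b`, `loc`,
// `tdescTy`, and `src` below are assumptions, not taken from the original
// source). A caller with a dynamically shaped memref might construct the
// descriptor by passing OpFoldResults for offsets, shape, and strides:
//   SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), b.getIndexAttr(0)};
//   SmallVector<OpFoldResult> shape   = {b.getIndexAttr(64), b.getIndexAttr(64)};
//   SmallVector<OpFoldResult> strides = {b.getIndexAttr(64), b.getIndexAttr(1)};
//   auto desc = b.create<xegpu::CreateNdDescOp>(loc, tdescTy, src, offsets,
//                                               shape, strides);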
LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;
  // ...
  // The memory space of the created TensorDesc must match that of the source.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // If the source is a memref, its rank must match the number of offsets.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy)
    invalidRank |= (memrefTy.getRank() != rank);
  // ...
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // The result TensorDesc rank is limited to 2 and may not exceed `rank`.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank is up to 2 and not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  return success();
}
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = tdescTy.getArrayLength();
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();
    // The transpose permutation must stay within the TensorDesc rank.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });
    if (valid)
      transpose(trans, tdescShape);
    // ...
  }

  // Packed (VNNI) loads fold a vnni_factor out of the outer loaded dimension
  // and append it as a new trailing dimension.
  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only)";
    }
  }

  if (array_len > 1) {
    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  }

  auto sgMap = tdescTy.getSGMapAttr();
  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape."
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // destination TensorDesc
  auto valTy = getValueType();      // value to be stored

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  auto sgMap = dstTy.getSGMapAttr();
  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape."
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets must match the TensorDesc rank.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  // Materialize the offsets as a constant index vector.
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc must match that of the source.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  auto chunkSize = tdescTy.getChunkSize();

  // The chunk size must be one of the hardware-supported values.
  SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                              16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  // Per-lane accesses of more than one element must be 32-bit aligned.
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  // The total access size is limited to 512 bytes (LSC constraint).
  auto lscConstraints = 512 * 8; // in bits
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");

  // The TensorDesc shape is {num_offsets} or {num_offsets, chunk_size}.
  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescElemTy = tdescTy.getElementType();
  auto valueElemTy = getElementTypeOrSelf(valueTy);
  if (tdescElemTy != valueElemTy)
    return emitOpError(
        "Value should have the same element type as TensorDesc.");

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  // A rank-2 (chunked) load must be transposed so that the chunk dimension
  // becomes the leading dimension of the result.
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load of rank-2 tensor has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  if (auto sgMap = tdescTy.getSGMapAttr()) {
    // In SIMT mode each work item holds wi_data elements, which must agree
    // with both the chunk size and the per-lane value vector size.
    auto valueVecTy = cast<VectorType>(valueTy);
    const int32_t wiData =
        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
    if (valueVecTy.getNumElements() != wiData ||
        valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
      return emitOpError("Chunk size, vector size and wi_data must match.");
    }
    // Work items load 1D vectors, so drop the trailing descriptor dimension.
    tdescShape[tdescTy.getRank() - 1] = 1;
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto maskTy = getMaskType();
  auto valueTy = getValueType();
  auto maskShape = getShapeOf(maskTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  // A rank-2 (chunked) store must likewise be transposed.
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("Store of a rank-2 tensor has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  if (auto sgMap = tdescTy.getSGMapAttr()) {
    auto valueVecTy = cast<VectorType>(valueTy);
    const int32_t wiData =
        sgMap.getWiData()[0] > 1 ? sgMap.getWiData()[0] : sgMap.getWiData()[1];
    if (valueVecTy.getNumElements() != wiData ||
        valueVecTy.getNumElements() != tdescTy.getChunkSize()) {
      return emitOpError("Chunk size, vector size and wi_data must match.");
    }
    // Work items store 1D vectors, so drop the trailing descriptor dimension.
    tdescShape[tdescTy.getRank() - 1] = 1;
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected value shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  // Materialize the offsets as a constant index vector.
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();

  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
                       "2D or 3D (packed) vector.");

  // For a packed (3D) rhs, the effective K dimension is the product of its
  // outer dimension and the packing factor.
  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");

  return success();
}
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>